1use ark_ff::PrimeField;
2use ark_relations::r1cs::SynthesisError;
3use ark_std::{ops::Shr, ops::ShrAssign};
4
5use crate::boolean::Boolean;
6
7use super::{PrimUInt, UInt};
8
9impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
10 fn _shr_u128(&self, other: u128) -> Result<Self, SynthesisError> {
11 if other < N as u128 {
12 let mut bits = [Boolean::FALSE; N];
13 for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) {
14 *a = b.clone();
15 }
16
17 let value = self.value.and_then(|a| Some(a >> other));
18 Ok(Self { bits, value })
19 } else {
20 panic!("attempt to shift right with overflow")
21 }
22 }
23}
24
25impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for UInt<N, T, F> {
26 type Output = Self;
27
28 #[tracing::instrument(target = "r1cs", skip(self, other))]
51 fn shr(self, other: T2) -> Self::Output {
52 self._shr_u128(other.into()).unwrap()
53 }
54}
55
56impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for &'a UInt<N, T, F> {
57 type Output = UInt<N, T, F>;
58
59 #[tracing::instrument(target = "r1cs", skip(self, other))]
60 fn shr(self, other: T2) -> Self::Output {
61 self._shr_u128(other.into()).unwrap()
62 }
63}
64
65impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> ShrAssign<T2> for UInt<N, T, F> {
66 #[tracing::instrument(target = "r1cs", skip(self, other))]
90 fn shr_assign(&mut self, other: T2) {
91 let result = self._shr_u128(other.into()).unwrap();
92 *self = result;
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{AllocVar, AllocationMode},
        prelude::EqGadget,
        uint::test_utils::{run_binary_exhaustive_native_only, run_binary_random_native_only},
        R1CSVar,
    };
    use ark_ff::PrimeField;
    use ark_test_curves::bls12_381::Fr;

    /// Checks `&a >> b` against the native shift of `a`'s value.
    ///
    /// Steps:
    /// - Reduce `b` modulo `N`, so every call takes the in-range path.
    /// - Allocate the expected result in the same mode as `a`: constant if `a`
    ///   is constant, witness otherwise.
    /// - Enforce equality with the computed result and, for non-constant
    ///   inputs, require the constraint system to be satisfied.
    ///
    /// Because of the modulo reduction, the overflow panic is never reached
    /// here. It is covered by the dedicated `*_panics` tests below.
    fn uint_shr<T: PrimUInt, const N: usize, F: PrimeField>(
        a: UInt<N, T, F>,
        b: T,
    ) -> Result<(), SynthesisError> {
        let cs = a.cs();
        let b = b.into() % (N as u128);
        let computed = &a >> b;
        let expected_mode = if a.is_constant() {
            AllocationMode::Constant
        } else {
            AllocationMode::Witness
        };
        let expected =
            UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? >> b), expected_mode)?;
        assert_eq!(expected.value(), computed.value());
        expected.enforce_equal(&computed)?;
        if !a.is_constant() {
            assert!(cs.is_satisfied().unwrap());
        }
        Ok(())
    }

    #[test]
    fn u8_shr() {
        run_binary_exhaustive_native_only(uint_shr::<u8, 8, Fr>).unwrap()
    }

    #[test]
    fn u16_shr() {
        run_binary_random_native_only::<1000, 16, _, _>(uint_shr::<u16, 16, Fr>).unwrap()
    }

    #[test]
    fn u32_shr() {
        run_binary_random_native_only::<1000, 32, _, _>(uint_shr::<u32, 32, Fr>).unwrap()
    }

    #[test]
    fn u64_shr() {
        run_binary_random_native_only::<1000, 64, _, _>(uint_shr::<u64, 64, Fr>).unwrap()
    }

    #[test]
    fn u128_shr() {
        run_binary_random_native_only::<1000, 128, _, _>(uint_shr::<u128, 128, Fr>).unwrap()
    }

    /// Shifting by the full bit width must panic, just like native integers
    /// do in debug builds.
    #[test]
    #[should_panic(expected = "attempt to shift right with overflow")]
    fn u8_shr_by_bit_width_panics() {
        let a = UInt::<8, u8, Fr>::constant(1);
        let _ = a >> 8u8;
    }

    /// The in-place operator shares the same overflow check.
    #[test]
    #[should_panic(expected = "attempt to shift right with overflow")]
    fn u8_shr_assign_by_bit_width_panics() {
        let mut a = UInt::<8, u8, Fr>::constant(1);
        a >>= 8u8;
    }
}