clock_curve_math/bigint/
backend_bigint.rs1use core::cmp::Ordering;
4
5use clock_bigint::U256;
6
7use crate::ct;
8
9#[derive(Clone, Copy, Debug)]
11pub struct BigInt {
12 inner: U256,
13}
14
15impl BigInt {
16 pub fn from_limbs(limbs: &[u64; 4]) -> Self {
18 Self {
19 inner: U256::from_limbs(limbs).unwrap(),
20 }
21 }
22
23 pub fn from_bytes(bytes: &[u8; 32]) -> Self {
25 Self {
26 inner: U256::from_limbs(&[
27 u64::from_le_bytes(bytes[0..8].try_into().unwrap()),
28 u64::from_le_bytes(bytes[8..16].try_into().unwrap()),
29 u64::from_le_bytes(bytes[16..24].try_into().unwrap()),
30 u64::from_le_bytes(bytes[24..32].try_into().unwrap()),
31 ])
32 .unwrap(),
33 }
34 }
35
36 pub fn from_u64(value: u64) -> Self {
38 Self {
39 inner: U256::from_u64(value),
40 }
41 }
42
43 pub fn to_bytes(&self) -> [u8; 32] {
45 let limbs = self.inner.as_limbs();
46 let mut bytes = [0u8; 32];
47 for (i, &limb) in limbs.iter().enumerate() {
48 bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes());
49 }
50 bytes
51 }
52
53 pub fn limbs(&self) -> [u64; 4] {
55 let slice = self.inner.as_limbs();
56 [slice[0], slice[1], slice[2], slice[3]]
57 }
58
59 pub fn add(&self, rhs: &Self) -> Self {
61 let a_limbs = self.limbs();
62 let b_limbs = rhs.limbs();
63 let mut result_limbs = [0u64; 4];
64 let mut carry = 0u64;
65
66 for i in 0..4 {
67 let (sum, carry1) = a_limbs[i].overflowing_add(b_limbs[i]);
68 let (sum, carry2) = sum.overflowing_add(carry);
69 result_limbs[i] = sum;
70 carry = (carry1 as u64) + (carry2 as u64);
71 }
72
73 Self::from_limbs(&result_limbs)
74 }
75
76 pub fn sub(&self, rhs: &Self) -> Self {
78 let a_limbs = self.limbs();
79 let b_limbs = rhs.limbs();
80 let mut result_limbs = [0u64; 4];
81 let mut borrow = 0u64;
82
83 for i in 0..4 {
84 let (diff, borrow1) = a_limbs[i].overflowing_sub(b_limbs[i]);
85 let (diff, borrow2) = diff.overflowing_sub(borrow);
86 result_limbs[i] = diff;
87 borrow = (borrow1 as u64) + (borrow2 as u64);
88 }
89
90 Self::from_limbs(&result_limbs)
91 }
92
93 pub fn mul(&self, rhs: &Self) -> Self {
98 let a_limbs = self.limbs();
99 let b_limbs = rhs.limbs();
100 let mut result = [0u64; 8]; for i in 0..4 {
105 let mut carry = 0u128;
106 for j in 0..4 {
107 let prod = a_limbs[i] as u128 * b_limbs[j] as u128;
108 let sum = prod + carry + result[i + j] as u128;
109 result[i + j] = sum as u64;
110 carry = sum >> 64;
111 }
112 let mut k = i + 4;
114 while carry > 0 && k < 8 {
115 let sum = carry + result[k] as u128;
116 result[k] = sum as u64;
117 carry = sum >> 64;
118 k += 1;
119 }
120 }
121
122 Self::from_limbs(&[result[0], result[1], result[2], result[3]])
125 }
126
127 pub fn mul_wide(&self, rhs: &Self) -> (Self, Self) {
129 let a_limbs = self.limbs();
130 let b_limbs = rhs.limbs();
131 let mut result = [0u64; 8]; for i in 0..4 {
135 let mut carry = 0u128;
136 for j in 0..4 {
137 let prod = a_limbs[i] as u128 * b_limbs[j] as u128;
138 let sum = prod + carry + result[i + j] as u128;
139 result[i + j] = sum as u64;
140 carry = sum >> 64;
141 }
142 let mut k = i + 4;
144 while carry > 0 && k < 8 {
145 let sum = carry + result[k] as u128;
146 result[k] = sum as u64;
147 carry = sum >> 64;
148 k += 1;
149 }
150 }
151
152 let low = Self::from_limbs(&[result[0], result[1], result[2], result[3]]);
153 let high = Self::from_limbs(&[result[4], result[5], result[6], result[7]]);
154 (low, high)
155 }
156
157 pub fn div_rem(&self, divisor: &Self) -> (Self, Self) {
163 if divisor.is_zero() {
165 panic!("Division by zero");
166 }
167
168 if self.is_zero() {
170 return (BigInt::from_u64(0), BigInt::from_u64(0));
171 }
172
173 let cmp = self.cmp(divisor);
174 if cmp == core::cmp::Ordering::Less {
175 return (BigInt::from_u64(0), *self);
176 }
177 if cmp == core::cmp::Ordering::Equal {
178 return (BigInt::from_u64(1), BigInt::from_u64(0));
179 }
180
181 let mut quotient = BigInt::from_u64(0);
183 let mut remainder = *self;
184
185 let mut divisor_shifted = *divisor;
187 let mut shift_count = 0u32;
188
189 while divisor_shifted.cmp(&remainder) != core::cmp::Ordering::Greater {
191 divisor_shifted = divisor_shifted.shl(1);
192 shift_count += 1;
193 }
194
195 divisor_shifted = divisor_shifted.shr(1);
197 shift_count -= 1;
198
199 for _ in 0..=shift_count {
201 if remainder.cmp(&divisor_shifted) != core::cmp::Ordering::Less {
202 remainder = remainder.sub(&divisor_shifted);
203 let bit_mask = BigInt::from_u64(1).shl(shift_count);
205 quotient = quotient.add(&bit_mask);
206 }
207 divisor_shifted = divisor_shifted.shr(1);
208 if shift_count > 0 {
209 shift_count -= 1;
210 }
211 }
212
213 (quotient, remainder)
214 }
215
216 pub fn shl(&self, bits: u32) -> Self {
218 let limbs = self.limbs();
221 let mut result_limbs = [0u64; 4];
222
223 let limb_shift = (bits / 64) as usize;
224 let bit_shift = bits % 64;
225
226 if limb_shift >= 4 {
227 return Self::from_u64(0);
229 }
230
231 for i in 0..4 {
232 if i + limb_shift < 4 {
233 result_limbs[i + limb_shift] = limbs[i] << bit_shift;
234 if bit_shift > 0 && i + limb_shift + 1 < 4 {
235 result_limbs[i + limb_shift + 1] |= limbs[i] >> (64 - bit_shift);
236 }
237 }
238 }
239
240 Self::from_limbs(&result_limbs)
241 }
242
243 pub fn shr(&self, bits: u32) -> Self {
245 let limbs = self.limbs();
247 let mut result_limbs = [0u64; 4];
248
249 let limb_shift = (bits / 64) as usize;
250 let bit_shift = bits % 64;
251
252 if limb_shift >= 4 {
253 return Self::from_u64(0);
254 }
255
256 for i in 0..4 {
257 if i >= limb_shift {
258 result_limbs[i - limb_shift] = limbs[i] >> bit_shift;
259 if bit_shift > 0 && i > limb_shift {
260 result_limbs[i - limb_shift - 1] |= limbs[i] << (64 - bit_shift);
261 }
262 }
263 }
264
265 Self::from_limbs(&result_limbs)
266 }
267
268 pub fn cmp(&self, rhs: &Self) -> Ordering {
270 let self_limbs = self.limbs();
272 let rhs_limbs = rhs.limbs();
273
274 for i in (0..4).rev() {
275 if self_limbs[i] > rhs_limbs[i] {
276 return Ordering::Greater;
277 } else if self_limbs[i] < rhs_limbs[i] {
278 return Ordering::Less;
279 }
280 }
281
282 Ordering::Equal
283 }
284
285 pub fn is_zero(&self) -> bool {
287 let limbs = self.limbs();
288 let mut result = 1u64;
289 for limb in &limbs {
290 result &= ct::ct_is_zero(*limb);
291 }
292 result == 1
293 }
294}
295
296impl PartialEq for BigInt {
297 fn eq(&self, other: &Self) -> bool {
298 self.limbs() == other.limbs()
299 }
300}
301
302impl Eq for BigInt {}
303
304impl Default for BigInt {
305 fn default() -> Self {
306 Self::from_u64(0)
307 }
308}