clock_curve_math/bigint/backend_bigint.rs

//! BigInt backend: a 256-bit unsigned integer built on the `clock-bigint` crate's `U256`.
2
3use core::cmp::Ordering;
4
5use clock_bigint::U256;
6
7use crate::ct;
8
9/// BigInt wrapper around clock-bigint's U256.
10#[derive(Clone, Copy, Debug)]
11pub struct BigInt {
12    inner: U256,
13}
14
15impl BigInt {
16    /// Create a BigInt from u64 limbs (little-endian).
17    pub fn from_limbs(limbs: &[u64; 4]) -> Self {
18        Self {
19            inner: U256::from_limbs(limbs).unwrap(),
20        }
21    }
22
23    /// Create a BigInt from bytes (little-endian).
24    pub fn from_bytes(bytes: &[u8; 32]) -> Self {
25        Self {
26            inner: U256::from_limbs(&[
27                u64::from_le_bytes(bytes[0..8].try_into().unwrap()),
28                u64::from_le_bytes(bytes[8..16].try_into().unwrap()),
29                u64::from_le_bytes(bytes[16..24].try_into().unwrap()),
30                u64::from_le_bytes(bytes[24..32].try_into().unwrap()),
31            ])
32            .unwrap(),
33        }
34    }
35
36    /// Create a BigInt from a u64 value.
37    pub fn from_u64(value: u64) -> Self {
38        Self {
39            inner: U256::from_u64(value),
40        }
41    }
42
43    /// Convert to little-endian bytes.
44    pub fn to_bytes(&self) -> [u8; 32] {
45        let limbs = self.inner.as_limbs();
46        let mut bytes = [0u8; 32];
47        for (i, &limb) in limbs.iter().enumerate() {
48            bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes());
49        }
50        bytes
51    }
52
53    /// Get the limbs (little-endian).
54    pub fn limbs(&self) -> [u64; 4] {
55        let slice = self.inner.as_limbs();
56        [slice[0], slice[1], slice[2], slice[3]]
57    }
58
59    /// Constant-time addition.
60    pub fn add(&self, rhs: &Self) -> Self {
61        let a_limbs = self.limbs();
62        let b_limbs = rhs.limbs();
63        let mut result_limbs = [0u64; 4];
64        let mut carry = 0u64;
65
66        for i in 0..4 {
67            let (sum, carry1) = a_limbs[i].overflowing_add(b_limbs[i]);
68            let (sum, carry2) = sum.overflowing_add(carry);
69            result_limbs[i] = sum;
70            carry = (carry1 as u64) + (carry2 as u64);
71        }
72
73        Self::from_limbs(&result_limbs)
74    }
75
76    /// Constant-time subtraction.
77    pub fn sub(&self, rhs: &Self) -> Self {
78        let a_limbs = self.limbs();
79        let b_limbs = rhs.limbs();
80        let mut result_limbs = [0u64; 4];
81        let mut borrow = 0u64;
82
83        for i in 0..4 {
84            let (diff, borrow1) = a_limbs[i].overflowing_sub(b_limbs[i]);
85            let (diff, borrow2) = diff.overflowing_sub(borrow);
86            result_limbs[i] = diff;
87            borrow = (borrow1 as u64) + (borrow2 as u64);
88        }
89
90        Self::from_limbs(&result_limbs)
91    }
92
93    /// Constant-time multiplication with proper overflow handling.
94    ///
95    /// Returns the full 256-bit product. For modular arithmetic,
96    /// the caller is responsible for reduction if needed.
97    pub fn mul(&self, rhs: &Self) -> Self {
98        let a_limbs = self.limbs();
99        let b_limbs = rhs.limbs();
100        let mut result = [0u64; 8]; // 8 limbs for 256x256 = 512 bit result
101
102        // Optimized schoolbook multiplication using u128 to avoid overflow
103        // Process in a way that maximizes ILP (Instruction Level Parallelism)
104        for i in 0..4 {
105            let mut carry = 0u128;
106            for j in 0..4 {
107                let prod = a_limbs[i] as u128 * b_limbs[j] as u128;
108                let sum = prod + carry + result[i + j] as u128;
109                result[i + j] = sum as u64;
110                carry = sum >> 64;
111            }
112            // Propagate remaining carry
113            let mut k = i + 4;
114            while carry > 0 && k < 8 {
115                let sum = carry + result[k] as u128;
116                result[k] = sum as u64;
117                carry = sum >> 64;
118                k += 1;
119            }
120        }
121
122        // For 256-bit modular arithmetic, we typically want the result mod 2^256
123        // This is equivalent to taking the low 4 limbs
124        Self::from_limbs(&[result[0], result[1], result[2], result[3]])
125    }
126
127    /// Wide multiplication returning the full 512-bit result as (low, high) BigInts.
128    pub fn mul_wide(&self, rhs: &Self) -> (Self, Self) {
129        let a_limbs = self.limbs();
130        let b_limbs = rhs.limbs();
131        let mut result = [0u64; 8]; // 8 limbs for 256x256 = 512 bit result
132
133        // Simple schoolbook multiplication with u128 to avoid overflow
134        for i in 0..4 {
135            let mut carry = 0u128;
136            for j in 0..4 {
137                let prod = a_limbs[i] as u128 * b_limbs[j] as u128;
138                let sum = prod + carry + result[i + j] as u128;
139                result[i + j] = sum as u64;
140                carry = sum >> 64;
141            }
142            // Propagate remaining carry
143            let mut k = i + 4;
144            while carry > 0 && k < 8 {
145                let sum = carry + result[k] as u128;
146                result[k] = sum as u64;
147                carry = sum >> 64;
148                k += 1;
149            }
150        }
151
152        let low = Self::from_limbs(&[result[0], result[1], result[2], result[3]]);
153        let high = Self::from_limbs(&[result[4], result[5], result[6], result[7]]);
154        (low, high)
155    }
156
157    /// Division with remainder using binary long division.
158    ///
159    /// Returns (quotient, remainder) where self = quotient * divisor + remainder.
160    /// This implementation uses binary long division which is much more efficient
161    /// than repeated subtraction.
162    pub fn div_rem(&self, divisor: &Self) -> (Self, Self) {
163        // Handle division by zero
164        if divisor.is_zero() {
165            panic!("Division by zero");
166        }
167
168        // Handle simple cases
169        if self.is_zero() {
170            return (BigInt::from_u64(0), BigInt::from_u64(0));
171        }
172
173        let cmp = self.cmp(divisor);
174        if cmp == core::cmp::Ordering::Less {
175            return (BigInt::from_u64(0), *self);
176        }
177        if cmp == core::cmp::Ordering::Equal {
178            return (BigInt::from_u64(1), BigInt::from_u64(0));
179        }
180
181        // Use binary search to find the quotient more efficiently
182        let mut quotient = BigInt::from_u64(0);
183        let mut remainder = *self;
184
185        // Binary search for quotient
186        // Upper bound is self.bit_length() - divisor.bit_length() + 1 bits
187        let mut low = BigInt::from_u64(0);
188        let mut high = BigInt::from_u64(1).shl((self.bit_length() - divisor.bit_length() + 1).min(256));
189
190        while low.cmp(&high) == core::cmp::Ordering::Less {
191            let mid = low.add(&high).shr(1);
192            let product = divisor.mul(&mid);
193
194            match product.cmp(&remainder) {
195                core::cmp::Ordering::Less | core::cmp::Ordering::Equal => {
196                    quotient = mid;
197                    low = mid.add(&BigInt::from_u64(1));
198                }
199                core::cmp::Ordering::Greater => {
200                    high = mid;
201                }
202            }
203        }
204
205        // Compute remainder
206        let product = divisor.mul(&quotient);
207        remainder = remainder.sub(&product);
208
209        (quotient, remainder)
210    }
211
212    /// Constant-time left shift.
213    pub fn shl(&self, bits: u32) -> Self {
214        // For now, use a simple implementation - clock-bigint may have shift functions
215        // This is a placeholder
216        let limbs = self.limbs();
217        let mut result_limbs = [0u64; 4];
218
219        let limb_shift = (bits / 64) as usize;
220        let bit_shift = bits % 64;
221
222        if limb_shift >= 4 {
223            // Shift by more than 256 bits results in zero
224            return Self::from_u64(0);
225        }
226
227        for i in 0..4 {
228            if i + limb_shift < 4 {
229                result_limbs[i + limb_shift] = limbs[i] << bit_shift;
230                if bit_shift > 0 && i + limb_shift + 1 < 4 {
231                    result_limbs[i + limb_shift + 1] |= limbs[i] >> (64 - bit_shift);
232                }
233            }
234        }
235
236        Self::from_limbs(&result_limbs)
237    }
238
239    /// Constant-time right shift.
240    pub fn shr(&self, bits: u32) -> Self {
241        // Similar to left shift but in reverse
242        let limbs = self.limbs();
243        let mut result_limbs = [0u64; 4];
244
245        let limb_shift = (bits / 64) as usize;
246        let bit_shift = bits % 64;
247
248        if limb_shift >= 4 {
249            return Self::from_u64(0);
250        }
251
252        for i in 0..4 {
253            if i >= limb_shift {
254                result_limbs[i - limb_shift] = limbs[i] >> bit_shift;
255                if bit_shift > 0 && i > limb_shift {
256                    result_limbs[i - limb_shift - 1] |= limbs[i] << (64 - bit_shift);
257                }
258            }
259        }
260
261        Self::from_limbs(&result_limbs)
262    }
263
264    /// Constant-time comparison.
265    pub fn cmp(&self, rhs: &Self) -> Ordering {
266        // Compare from most significant limb to least significant
267        let self_limbs = self.limbs();
268        let rhs_limbs = rhs.limbs();
269
270        for i in (0..4).rev() {
271            if self_limbs[i] > rhs_limbs[i] {
272                return Ordering::Greater;
273            } else if self_limbs[i] < rhs_limbs[i] {
274                return Ordering::Less;
275            }
276        }
277
278        Ordering::Equal
279    }
280
281    /// Constant-time zero check.
282    pub fn is_zero(&self) -> bool {
283        let limbs = self.limbs();
284        let mut result = 1u64;
285        for limb in &limbs {
286            result &= ct::ct_is_zero(*limb);
287        }
288        result == 1
289    }
290
291    /// Get the bit at the specified position (constant-time).
292    ///
293    /// Returns 1 if the bit is set, 0 otherwise.
294    /// Bit positions start from 0 (least significant bit).
295    pub fn get_bit(&self, bit_position: u32) -> u64 {
296        let limb_index = (bit_position / 64) as usize;
297        let bit_index = bit_position % 64;
298
299        if limb_index >= 4 {
300            return 0;
301        }
302
303        let limbs = self.limbs();
304        let limb = limbs[limb_index];
305        ((limb >> bit_index) & 1) as u64
306    }
307
308    /// Get the number of bits needed to represent this BigInt.
309    ///
310    /// Returns the position of the highest set bit plus one.
311    /// Returns 0 for zero.
312    pub fn bit_length(&self) -> u32 {
313        let limbs = self.limbs();
314
315        // Find the highest non-zero limb
316        for i in (0..4).rev() {
317            if limbs[i] != 0 {
318                // Find the highest bit in this limb
319                let mut bit = 63;
320                while bit > 0 {
321                    if (limbs[i] & (1u64 << bit)) != 0 {
322                        return (i as u32 * 64) + bit + 1;
323                    }
324                    bit -= 1;
325                }
326                // If we reach here, it's bit 0
327                return (i as u32 * 64) + 1;
328            }
329        }
330
331        0 // Zero has no bits set
332    }
333}
334
335impl PartialEq for BigInt {
336    fn eq(&self, other: &Self) -> bool {
337        self.limbs() == other.limbs()
338    }
339}
340
/// `PartialEq` compares the limbs directly, which is reflexive, symmetric and
/// transitive, so `BigInt` also satisfies the full `Eq` contract.
impl Eq for BigInt {}
342
343impl Default for BigInt {
344    fn default() -> Self {
345        Self::from_u64(0)
346    }
347}