clock_curve_math/bigint/
backend_bigint.rs

1//! BigInt backend using clock-bigint.
2
3use core::cmp::Ordering;
4
5use clock_bigint::U256;
6
7use crate::ct;
8
9/// BigInt wrapper around clock-bigint's U256.
10#[derive(Clone, Copy, Debug)]
11pub struct BigInt {
12    inner: U256,
13}
14
15impl BigInt {
16    /// Create a BigInt from u64 limbs (little-endian).
17    pub fn from_limbs(limbs: &[u64; 4]) -> Self {
18        Self {
19            inner: U256::from_limbs(limbs).unwrap(),
20        }
21    }
22
23    /// Create a BigInt from bytes (little-endian).
24    pub fn from_bytes(bytes: &[u8; 32]) -> Self {
25        Self {
26            inner: U256::from_limbs(&[
27                u64::from_le_bytes(bytes[0..8].try_into().unwrap()),
28                u64::from_le_bytes(bytes[8..16].try_into().unwrap()),
29                u64::from_le_bytes(bytes[16..24].try_into().unwrap()),
30                u64::from_le_bytes(bytes[24..32].try_into().unwrap()),
31            ])
32            .unwrap(),
33        }
34    }
35
36    /// Create a BigInt from a u64 value.
37    pub fn from_u64(value: u64) -> Self {
38        Self {
39            inner: U256::from_u64(value),
40        }
41    }
42
43    /// Convert to little-endian bytes.
44    pub fn to_bytes(&self) -> [u8; 32] {
45        let limbs = self.inner.as_limbs();
46        let mut bytes = [0u8; 32];
47        for (i, &limb) in limbs.iter().enumerate() {
48            bytes[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes());
49        }
50        bytes
51    }
52
53    /// Get the limbs (little-endian).
54    pub fn limbs(&self) -> [u64; 4] {
55        let slice = self.inner.as_limbs();
56        [slice[0], slice[1], slice[2], slice[3]]
57    }
58
59    /// Constant-time addition.
60    pub fn add(&self, rhs: &Self) -> Self {
61        let a_limbs = self.limbs();
62        let b_limbs = rhs.limbs();
63        let mut result_limbs = [0u64; 4];
64        let mut carry = 0u64;
65
66        for i in 0..4 {
67            let (sum, carry1) = a_limbs[i].overflowing_add(b_limbs[i]);
68            let (sum, carry2) = sum.overflowing_add(carry);
69            result_limbs[i] = sum;
70            carry = (carry1 as u64) + (carry2 as u64);
71        }
72
73        Self::from_limbs(&result_limbs)
74    }
75
76    /// Constant-time subtraction.
77    pub fn sub(&self, rhs: &Self) -> Self {
78        let a_limbs = self.limbs();
79        let b_limbs = rhs.limbs();
80        let mut result_limbs = [0u64; 4];
81        let mut borrow = 0u64;
82
83        for i in 0..4 {
84            let (diff, borrow1) = a_limbs[i].overflowing_sub(b_limbs[i]);
85            let (diff, borrow2) = diff.overflowing_sub(borrow);
86            result_limbs[i] = diff;
87            borrow = (borrow1 as u64) + (borrow2 as u64);
88        }
89
90        Self::from_limbs(&result_limbs)
91    }
92
93    /// Constant-time multiplication with proper overflow handling.
94    ///
95    /// Returns the full 256-bit product. For modular arithmetic,
96    /// the caller is responsible for reduction if needed.
97    pub fn mul(&self, rhs: &Self) -> Self {
98        let a_limbs = self.limbs();
99        let b_limbs = rhs.limbs();
100        let mut result = [0u64; 8]; // 8 limbs for 256x256 = 512 bit result
101
102        // Optimized schoolbook multiplication using u128 to avoid overflow
103        // Process in a way that maximizes ILP (Instruction Level Parallelism)
104        for i in 0..4 {
105            let mut carry = 0u128;
106            for j in 0..4 {
107                let prod = a_limbs[i] as u128 * b_limbs[j] as u128;
108                let sum = prod + carry + result[i + j] as u128;
109                result[i + j] = sum as u64;
110                carry = sum >> 64;
111            }
112            // Propagate remaining carry
113            let mut k = i + 4;
114            while carry > 0 && k < 8 {
115                let sum = carry + result[k] as u128;
116                result[k] = sum as u64;
117                carry = sum >> 64;
118                k += 1;
119            }
120        }
121
122        // For 256-bit modular arithmetic, we typically want the result mod 2^256
123        // This is equivalent to taking the low 4 limbs
124        Self::from_limbs(&[result[0], result[1], result[2], result[3]])
125    }
126
127    /// Wide multiplication returning the full 512-bit result as (low, high) BigInts.
128    pub fn mul_wide(&self, rhs: &Self) -> (Self, Self) {
129        let a_limbs = self.limbs();
130        let b_limbs = rhs.limbs();
131        let mut result = [0u64; 8]; // 8 limbs for 256x256 = 512 bit result
132
133        // Simple schoolbook multiplication with u128 to avoid overflow
134        for i in 0..4 {
135            let mut carry = 0u128;
136            for j in 0..4 {
137                let prod = a_limbs[i] as u128 * b_limbs[j] as u128;
138                let sum = prod + carry + result[i + j] as u128;
139                result[i + j] = sum as u64;
140                carry = sum >> 64;
141            }
142            // Propagate remaining carry
143            let mut k = i + 4;
144            while carry > 0 && k < 8 {
145                let sum = carry + result[k] as u128;
146                result[k] = sum as u64;
147                carry = sum >> 64;
148                k += 1;
149            }
150        }
151
152        let low = Self::from_limbs(&[result[0], result[1], result[2], result[3]]);
153        let high = Self::from_limbs(&[result[4], result[5], result[6], result[7]]);
154        (low, high)
155    }
156
157    /// Division with remainder using binary long division.
158    ///
159    /// Returns (quotient, remainder) where self = quotient * divisor + remainder.
160    /// This implementation uses binary long division which is much more efficient
161    /// than repeated subtraction.
162    pub fn div_rem(&self, divisor: &Self) -> (Self, Self) {
163        // Handle division by zero
164        if divisor.is_zero() {
165            panic!("Division by zero");
166        }
167
168        // Handle simple cases
169        if self.is_zero() {
170            return (BigInt::from_u64(0), BigInt::from_u64(0));
171        }
172
173        let cmp = self.cmp(divisor);
174        if cmp == core::cmp::Ordering::Less {
175            return (BigInt::from_u64(0), *self);
176        }
177        if cmp == core::cmp::Ordering::Equal {
178            return (BigInt::from_u64(1), BigInt::from_u64(0));
179        }
180
181        // Binary long division
182        let mut quotient = BigInt::from_u64(0);
183        let mut remainder = *self;
184
185        // Find the highest bit set in the divisor
186        let mut divisor_shifted = *divisor;
187        let mut shift_count = 0u32;
188
189        // Shift divisor left until it's greater than remainder
190        while divisor_shifted.cmp(&remainder) != core::cmp::Ordering::Greater {
191            divisor_shifted = divisor_shifted.shl(1);
192            shift_count += 1;
193        }
194
195        // Shift back one position
196        divisor_shifted = divisor_shifted.shr(1);
197        shift_count -= 1;
198
199        // Perform the division
200        for _ in 0..=shift_count {
201            if remainder.cmp(&divisor_shifted) != core::cmp::Ordering::Less {
202                remainder = remainder.sub(&divisor_shifted);
203                // Set the corresponding bit in quotient
204                let bit_mask = BigInt::from_u64(1).shl(shift_count);
205                quotient = quotient.add(&bit_mask);
206            }
207            divisor_shifted = divisor_shifted.shr(1);
208            if shift_count > 0 {
209                shift_count -= 1;
210            }
211        }
212
213        (quotient, remainder)
214    }
215
216    /// Constant-time left shift.
217    pub fn shl(&self, bits: u32) -> Self {
218        // For now, use a simple implementation - clock-bigint may have shift functions
219        // This is a placeholder
220        let limbs = self.limbs();
221        let mut result_limbs = [0u64; 4];
222
223        let limb_shift = (bits / 64) as usize;
224        let bit_shift = bits % 64;
225
226        if limb_shift >= 4 {
227            // Shift by more than 256 bits results in zero
228            return Self::from_u64(0);
229        }
230
231        for i in 0..4 {
232            if i + limb_shift < 4 {
233                result_limbs[i + limb_shift] = limbs[i] << bit_shift;
234                if bit_shift > 0 && i + limb_shift + 1 < 4 {
235                    result_limbs[i + limb_shift + 1] |= limbs[i] >> (64 - bit_shift);
236                }
237            }
238        }
239
240        Self::from_limbs(&result_limbs)
241    }
242
243    /// Constant-time right shift.
244    pub fn shr(&self, bits: u32) -> Self {
245        // Similar to left shift but in reverse
246        let limbs = self.limbs();
247        let mut result_limbs = [0u64; 4];
248
249        let limb_shift = (bits / 64) as usize;
250        let bit_shift = bits % 64;
251
252        if limb_shift >= 4 {
253            return Self::from_u64(0);
254        }
255
256        for i in 0..4 {
257            if i >= limb_shift {
258                result_limbs[i - limb_shift] = limbs[i] >> bit_shift;
259                if bit_shift > 0 && i > limb_shift {
260                    result_limbs[i - limb_shift - 1] |= limbs[i] << (64 - bit_shift);
261                }
262            }
263        }
264
265        Self::from_limbs(&result_limbs)
266    }
267
268    /// Constant-time comparison.
269    pub fn cmp(&self, rhs: &Self) -> Ordering {
270        // Compare from most significant limb to least significant
271        let self_limbs = self.limbs();
272        let rhs_limbs = rhs.limbs();
273
274        for i in (0..4).rev() {
275            if self_limbs[i] > rhs_limbs[i] {
276                return Ordering::Greater;
277            } else if self_limbs[i] < rhs_limbs[i] {
278                return Ordering::Less;
279            }
280        }
281
282        Ordering::Equal
283    }
284
285    /// Constant-time zero check.
286    pub fn is_zero(&self) -> bool {
287        let limbs = self.limbs();
288        let mut result = 1u64;
289        for limb in &limbs {
290            result &= ct::ct_is_zero(*limb);
291        }
292        result == 1
293    }
294}
295
296impl PartialEq for BigInt {
297    fn eq(&self, other: &Self) -> bool {
298        self.limbs() == other.limbs()
299    }
300}
301
302impl Eq for BigInt {}
303
304impl Default for BigInt {
305    fn default() -> Self {
306        Self::from_u64(0)
307    }
308}