Skip to main content

ark_ff/biginteger/
arithmetic.rs

1use ark_std::{vec, vec::Vec};
2
/// Adds `b` and `carry` into `a` in place.
///
/// `a` receives the low 64 bits of `a + b + carry`; the bits that did not fit
/// (the outgoing carry) are returned.
#[inline(always)]
#[doc(hidden)]
pub const fn adc(a: &mut u64, b: u64, carry: u64) -> u64 {
    // Two overflowing steps: each can overflow at most once, so the outgoing
    // carry is simply the number of overflows that happened.
    let (partial, overflowed_b) = a.overflowing_add(b);
    let (total, overflowed_carry) = partial.overflowing_add(carry);
    *a = total;
    overflowed_b as u64 + overflowed_carry as u64
}
11
/// Adds `b` and `carry` into `a` in place, returning the outgoing carry as a `u8`.
///
/// With the `asm` feature enabled on x86_64 this uses the `_addcarry_u64`
/// intrinsic; every other configuration computes the sum in 128 bits.
#[inline(always)]
#[doc(hidden)]
pub fn adc_for_add_with_carry(a: &mut u64, b: u64, carry: u8) -> u8 {
    // SAFETY: the cfg restricts this path to x86_64, and the intrinsic's only
    // output location is `a`, a valid, exclusively borrowed `u64`.
    #[cfg(all(target_arch = "x86_64", feature = "asm"))]
    #[allow(unused_unsafe, unsafe_code)]
    unsafe {
        use core::arch::x86_64::_addcarry_u64;
        _addcarry_u64(carry, *a, b, a)
    }
    #[cfg(not(all(target_arch = "x86_64", feature = "asm")))]
    {
        let wide = u128::from(*a) + u128::from(b) + u128::from(carry);
        *a = wide as u64;
        (wide >> 64) as u8
    }
}
29
/// Returns `a + b + *carry` truncated to 64 bits.
///
/// Any overflow beyond 64 bits is discarded. `carry` is only read; no outgoing
/// carry is reported.
#[inline(always)]
#[doc(hidden)]
pub const fn adc_no_carry(a: u64, b: u64, carry: &u64) -> u64 {
    a.wrapping_add(b).wrapping_add(*carry)
}
37
/// Subtracts `b` and `borrow` from `a` in place, returning the outgoing borrow.
///
/// `a` receives the low 64 bits of `a - b - borrow` (wrapping). The return value
/// is `1` if the subtraction went below zero and `0` otherwise.
#[inline(always)]
pub(crate) const fn sbb(a: &mut u64, b: u64, borrow: u64) -> u64 {
    // Bias the minuend by 2^64 so the 128-bit difference cannot underflow for
    // `borrow <= 1`. Bit 64 then survives exactly when no borrow was needed.
    let biased = ((1u128 << 64) | *a as u128) - b as u128 - borrow as u128;
    *a = biased as u64;
    ((biased >> 64) as u64) ^ 1
}
45
/// Subtracts `b` and `borrow` from `a` in place, returning the outgoing borrow as a `u8`.
///
/// `a` receives the low 64 bits of `a - b - borrow`. The return value is `1` if
/// the subtraction went below zero and `0` otherwise.
///
/// Like `adc_for_add_with_carry`, the `_subborrow_u64` intrinsic is used only when
/// the `asm` feature is enabled on x86_64. All other builds take the portable
/// 128-bit path, which produces identical results.
#[inline(always)]
#[doc(hidden)]
pub fn sbb_for_sub_with_borrow(a: &mut u64, b: u64, borrow: u8) -> u8 {
    // SAFETY: the cfg restricts this path to x86_64, and the intrinsic's only
    // output location is `a`, a valid, exclusively borrowed `u64`.
    #[cfg(all(target_arch = "x86_64", feature = "asm"))]
    #[allow(unused_unsafe, unsafe_code)]
    unsafe {
        use core::arch::x86_64::_subborrow_u64;
        _subborrow_u64(borrow, *a, b, a)
    }
    #[cfg(not(all(target_arch = "x86_64", feature = "asm")))]
    {
        // Bias by 2^64 so the difference stays non-negative for a 0/1 borrow;
        // bit 64 is cleared exactly when a borrow out was needed.
        let tmp = (1u128 << 64) + (*a as u128) - (b as u128) - (borrow as u128);
        *a = tmp as u64;
        u8::from(tmp >> 64 == 0)
    }
}
63
/// Returns the full 128-bit product `a * b`.
///
/// On wasm targets the product is assembled from four 32x32->64-bit partial
/// products instead of a single 64x64->128-bit multiply.
#[inline(always)]
#[doc(hidden)]
pub const fn widening_mul(a: u64, b: u64) -> u128 {
    #[cfg(not(target_family = "wasm"))]
    {
        (a as u128) * (b as u128)
    }
    #[cfg(target_family = "wasm")]
    {
        let (a_hi, a_lo) = (a >> 32, a & 0xFFFF_FFFF);
        let (b_hi, b_lo) = (b >> 32, b & 0xFFFF_FFFF);

        // Each 32x32 partial product fits in a u64, so none of these overflow.
        let low = (a_lo * b_lo) as u128;
        let high = ((a_hi * b_hi) as u128) << 64;
        // The two cross terms sum to < 2^65 and are then shifted into bits 32..97.
        let middle = (((a_lo * b_hi) as u128) + ((a_hi * b_lo) as u128)) << 32;
        // `low` and `high` occupy disjoint bit ranges, so `|` is an exact add.
        (high | low) + middle
    }
}
85
86/// Calculate a + b * c, returning the lower 64 bits of the result and setting
87/// `carry` to the upper 64 bits.
88#[inline(always)]
89#[doc(hidden)]
90pub const fn mac(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 {
91    let tmp = (a as u128) + widening_mul(b, c);
92    *carry = (tmp >> 64) as u64;
93    tmp as u64
94}
95
96/// Calculate a + b * c, discarding the lower 64 bits of the result and setting
97/// `carry` to the upper 64 bits.
98#[inline(always)]
99#[doc(hidden)]
100pub const fn mac_discard(a: u64, b: u64, c: u64, carry: &mut u64) {
101    let tmp = (a as u128) + widening_mul(b, c);
102    *carry = (tmp >> 64) as u64;
103}
104
105/// Calculate a + (b * c) + carry, returning the least significant digit
106/// and setting carry to the most significant digit.
107#[inline(always)]
108#[doc(hidden)]
109pub const fn mac_with_carry(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 {
110    let tmp = (a as u128) + widening_mul(b, c) + (*carry as u128);
111    *carry = (tmp >> 64) as u64;
112    tmp as u64
113}
114
115/// Compute the NAF (non-adjacent form) of num
116pub fn find_naf(num: &[u64]) -> Vec<i8> {
117    let mut num = num.to_vec();
118    let mut res = vec![];
119
120    // Helper functions for arithmetic operations
121    // Check if the number is non-zero
122    let is_non_zero = |num: &[u64]| num.iter().any(|&x| x != 0);
123    // Check if the number is odd
124    let is_odd = |num: &[u64]| num[0] & 1 == 1;
125    // Subtract a value `z` without borrow propagation
126    let sub_noborrow = |num: &mut [u64], z: u64| {
127        num.iter_mut()
128            .zip(ark_std::iter::once(z).chain(ark_std::iter::repeat(0)))
129            .fold(0, |borrow, (a, b)| sbb(a, b, borrow));
130    };
131    // Add a value `z` without carry propagation
132    let add_nocarry = |num: &mut [u64], z: u64| {
133        num.iter_mut()
134            .zip(ark_std::iter::once(z).chain(ark_std::iter::repeat(0)))
135            .fold(0, |carry, (a, b)| adc(a, b, carry));
136    };
137    // Perform an in-place division of the number by 2
138    let div2 = |num: &mut [u64]| {
139        num.iter_mut().rev().fold(0, |carry, x| {
140            let next_carry = *x << 63;
141            *x = (*x >> 1) | carry;
142            next_carry
143        });
144    };
145
146    // Main loop for NAF computation
147    while is_non_zero(&num) {
148        // Determine the current digit of the NAF representation
149        let z = if is_odd(&num) {
150            let z = 2 - (num[0] % 4) as i8;
151            if z >= 0 {
152                sub_noborrow(&mut num, z as u64);
153            } else {
154                add_nocarry(&mut num, (-z) as u64);
155            }
156            z
157        } else {
158            0
159        };
160
161        // Append the digit to the result
162        res.push(z);
163        // Divide the number by 2 for the next iteration
164        div2(&mut num);
165    }
166
167    res
168}
169
170/// We define relaxed NAF as a variant of NAF with a very small tweak.
171///
172/// Note that the cost of scalar multiplication grows with the length of the sequence (for doubling)
173/// plus the Hamming weight of the sequence (for addition, or subtraction).
174///
175/// NAF is optimizing for the Hamming weight only and therefore can be suboptimal.
176/// For example, NAF may generate a sequence (in little-endian) of the form ...0 -1 0 1.
177///
178/// This can be rewritten as ...0 1 1 to avoid one doubling, at the cost that we are making an
179/// exception of non-adjacence for the most significant bit.
180///
181/// Since this representation is no longer a strict NAF, we call it "relaxed NAF".
182pub fn find_relaxed_naf(num: &[u64]) -> Vec<i8> {
183    let mut res = find_naf(num);
184
185    let len = res.len();
186    if res[len - 2] == 0 && res[len - 3] == -1 {
187        res[len - 3] = 1;
188        res[len - 2] = 1;
189        res.resize(len - 1, 0);
190    }
191
192    res
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adc() {
        // Addition without an incoming carry
        let mut a = 5u64;
        let carry = adc(&mut a, 10u64, 0);
        assert_eq!(a, 15); // 5 + 10 = 15
        assert_eq!(carry, 0); // No carry should be generated

        // Addition that overflows u64
        let mut a = u64::MAX;
        let carry = adc(&mut a, 1u64, 0);
        assert_eq!(a, 0); // Overflow wraps `a` around to 0
        assert_eq!(carry, 1); // Carry is 1 due to overflow

        // Addition with a non-zero incoming carry
        let mut a = 5u64;
        let carry = adc(&mut a, 10u64, 1);
        assert_eq!(a, 16); // 5 + 10 + 1 = 16
        assert_eq!(carry, 0); // No overflow, so carry remains 0

        // Addition with an incoming carry and a large sum
        let mut a = u64::MAX - 5;
        let carry = adc(&mut a, 10u64, 1);
        assert_eq!(a, 5); // (u64::MAX - 5 + 10 + 1) wraps around to 5
        assert_eq!(carry, 1); // Carry is 1 due to overflow
    }

    #[test]
    fn test_adc_for_add_with_carry() {
        // Addition without an incoming carry
        let mut a = 5u64;
        let carry = adc_for_add_with_carry(&mut a, 10u64, 0);
        assert_eq!(a, 15); // 5 + 10 = 15
        assert_eq!(carry, 0); // No carry should be generated

        // Addition that overflows u64
        let mut a = u64::MAX;
        let carry = adc_for_add_with_carry(&mut a, 1u64, 0);
        assert_eq!(a, 0); // Overflow wraps `a` around to 0
        assert_eq!(carry, 1); // Carry is 1 due to overflow

        // Addition with a non-zero incoming carry
        let mut a = 5u64;
        let carry = adc_for_add_with_carry(&mut a, 10u64, 1);
        assert_eq!(a, 16); // 5 + 10 + 1 = 16
        assert_eq!(carry, 0); // No overflow, so carry remains 0

        // Addition with an incoming carry and a large sum
        let mut a = u64::MAX - 5;
        let carry = adc_for_add_with_carry(&mut a, 10u64, 1);
        assert_eq!(a, 5); // (u64::MAX - 5 + 10 + 1) wraps around to 5
        assert_eq!(carry, 1); // Carry is 1 due to overflow
    }

    #[test]
    fn test_adc_no_carry() {
        // `adc_no_carry` only reads `carry`. It returns the sum truncated to
        // 64 bits and never reports an outgoing carry.

        // Addition with a zero carry-in
        assert_eq!(adc_no_carry(5u64, 10u64, &0), 15); // 5 + 10 = 15

        // Addition with a non-zero carry-in
        assert_eq!(adc_no_carry(5u64, 10u64, &1), 16); // 5 + 10 + 1 = 16

        // Additions that overflow: the bits above 64 are discarded
        assert_eq!(adc_no_carry(u64::MAX, 1u64, &1), 1); // 2^64 + 1 -> 1
        assert_eq!(
            adc_no_carry(u64::MAX, u64::MAX, &u64::MAX),
            u64::MAX - 2 // 3 * (2^64 - 1) -> 2^64 - 3
        );
    }

    #[test]
    fn test_sbb() {
        // Subtraction without an incoming borrow
        let mut a = 15u64;
        let borrow = sbb(&mut a, 5u64, 0);
        assert_eq!(a, 10); // 15 - 5 = 10
        assert_eq!(borrow, 0); // No borrow should be generated

        // Subtraction that causes a borrow
        let mut a = 5u64;
        let borrow = sbb(&mut a, 10u64, 0);
        assert_eq!(a, u64::MAX - 4); // Underflow, wrapping around
        assert_eq!(borrow, 1); // Borrow should be 1

        // Subtraction with a non-zero incoming borrow
        let mut a = 15u64;
        let borrow = sbb(&mut a, 5u64, 1);
        assert_eq!(a, 9); // 15 - 5 - 1 = 9
        assert_eq!(borrow, 0); // No borrow should be generated

        // Subtraction with an incoming borrow and a large subtrahend
        let mut a = 0u64;
        let borrow = sbb(&mut a, u64::MAX, 1);
        assert_eq!(a, 0); // 0 - (u64::MAX + 1) -> 0
        assert_eq!(borrow, 1); // Borrow should be 1
    }

    #[test]
    fn test_sbb_for_sub_with_borrow() {
        // Subtraction without an incoming borrow
        let mut a = 15u64;
        let borrow = sbb_for_sub_with_borrow(&mut a, 5u64, 0);
        assert_eq!(a, 10); // 15 - 5 = 10
        assert_eq!(borrow, 0); // No borrow should be generated

        // Subtraction that causes a borrow
        let mut a = 5u64;
        let borrow = sbb_for_sub_with_borrow(&mut a, 10u64, 0);
        assert_eq!(a, u64::MAX - 4); // Underflow, wrapping around
        assert_eq!(borrow, 1); // Borrow should be 1

        // Subtraction with a non-zero incoming borrow
        let mut a = 15u64;
        let borrow = sbb_for_sub_with_borrow(&mut a, 5u64, 1);
        assert_eq!(a, 9); // 15 - 5 - 1 = 9
        assert_eq!(borrow, 0); // No borrow should be generated

        // Subtraction with an incoming borrow and a large subtrahend
        let mut a = 0u64;
        let borrow = sbb_for_sub_with_borrow(&mut a, u64::MAX, 1);
        assert_eq!(a, 0); // 0 - (u64::MAX + 1) -> 0
        assert_eq!(borrow, 1); // Borrow should be 1
    }

    #[test]
    fn test_widening_mul() {
        assert_eq!(widening_mul(0, u64::MAX), 0);
        assert_eq!(widening_mul(3, 7), 21);
        assert_eq!(widening_mul(1 << 32, 1 << 32), 1u128 << 64);
        // (2^64 - 1)^2 = 2^128 - 2^65 + 1
        assert_eq!(
            widening_mul(u64::MAX, u64::MAX),
            u128::MAX - (1u128 << 65) + 2
        );
    }

    #[test]
    fn test_mac() {
        // Basic multiply-accumulate without carry
        let mut carry = 0;
        let result = mac(1u64, 2u64, 3u64, &mut carry);
        assert_eq!(result, 7); // 1 + (2 * 3) = 7
        assert_eq!(carry, 0); // No overflow, carry remains 0

        // Multiply-accumulate with large values that generate a carry
        let mut carry = 0;
        let result = mac(u64::MAX, u64::MAX, 1u64, &mut carry);
        assert_eq!(result, u64::MAX - 1); // Result wraps around
        assert_eq!(carry, 1); // Carry is set due to overflow
    }

    #[test]
    fn test_mac_discard() {
        // Discard lower 64 bits and set carry
        let mut carry = 0;
        mac_discard(1u64, 2u64, 3u64, &mut carry);
        assert_eq!(carry, 0); // No overflow, carry remains 0

        // Values that generate a carry
        let mut carry = 0;
        mac_discard(u64::MAX, u64::MAX, 1u64, &mut carry);
        assert_eq!(carry, 1); // Carry is set due to overflow
    }

    #[test]
    fn test_mac_with_carry() {
        // Basic multiply-accumulate with carry
        let mut carry = 1;
        let result = mac_with_carry(1u64, 2u64, 3u64, &mut carry);
        assert_eq!(result, 8); // 1 + (2 * 3) + 1 = 8
        assert_eq!(carry, 0); // No overflow, carry becomes 0

        // Multiply-accumulate with carry and large values
        let mut carry = 1;
        let result = mac_with_carry(u64::MAX, u64::MAX, 1u64, &mut carry);
        assert_eq!(result, u64::MAX); // Result wraps around
        assert_eq!(carry, 1); // Carry is set due to overflow
    }

    #[test]
    fn test_find_relaxed_naf_usefulness() {
        let vec = find_naf(&[12u64]);
        assert_eq!(vec.len(), 5);

        let vec = find_relaxed_naf(&[12u64]);
        assert_eq!(vec.len(), 4);
    }

    #[test]
    fn test_find_relaxed_naf_correctness() {
        use ark_std::{One, UniformRand, Zero};
        use num_bigint::BigInt;

        let mut rng = ark_std::test_rng();

        for _ in 0..10 {
            let num = [
                u64::rand(&mut rng),
                u64::rand(&mut rng),
                u64::rand(&mut rng),
                u64::rand(&mut rng),
            ];
            let relaxed_naf = find_relaxed_naf(&num);

            let test = {
                let mut sum = BigInt::zero();
                let mut cur = BigInt::one();
                for v in relaxed_naf {
                    sum += cur.clone() * v;
                    cur *= 2;
                }
                sum
            };

            let test_expected = {
                let mut sum = BigInt::zero();
                let mut cur = BigInt::one();
                for v in num.iter() {
                    sum += cur.clone() * v;
                    cur <<= 64;
                }
                sum
            };

            assert_eq!(test, test_expected);
        }
    }

    #[test]
    fn test_find_naf_zero() {
        // Zero has an empty NAF
        let naf = find_naf(&[0]);
        assert!(naf.is_empty());
    }

    #[test]
    fn test_find_naf_single_digit() {
        // Small numbers
        assert_eq!(find_naf(&[1]), vec![1]);
        assert_eq!(find_naf(&[2]), vec![0, 1]);
        assert_eq!(find_naf(&[3]), vec![-1, 0, 1]);
        assert_eq!(find_naf(&[4]), vec![0, 0, 1]);
    }

    #[test]
    fn test_find_naf_large_number() {
        // 13 = 1 - 4 + 16
        assert_eq!(find_naf(&[13]), vec![1, 0, -1, 0, 1]);
    }

    #[test]
    fn test_find_naf_roundtrip_small() {
        for n in 0u64..1024 {
            let naf = find_naf(&[n]);
            // Every digit is -1, 0 or 1
            assert!(naf.iter().all(|&d| (-1..=1).contains(&d)));
            // No two adjacent digits are both non-zero
            assert!(naf.windows(2).all(|w| w[0] == 0 || w[1] == 0));
            // The digits reconstruct the input: sum(d_i * 2^i) == n
            let value: i128 = naf
                .iter()
                .enumerate()
                .map(|(i, &d)| i128::from(d) << i)
                .sum();
            assert_eq!(value, i128::from(n));
        }
    }

    #[test]
    fn test_find_naf_multiple_blocks() {
        // 2^64 spans two limbs: 64 zero digits followed by a single 1
        let num = [0, 1];
        let mut expected = vec![0i8; 64];
        expected.push(1);
        assert_eq!(find_naf(&num), expected);
    }

    #[test]
    fn test_find_naf_edge_cases() {
        // Edge cases
        let naf = find_naf(&[u64::MAX]);
        assert!(!naf.is_empty());
    }
}
458        // Test edge cases
459        let naf = find_naf(&[u64::MAX]);
460        assert!(!naf.is_empty());
461    }
462}