
blvm_secp256k1/scalar.rs

//! Scalar arithmetic modulo the secp256k1 group order n.
//!
//! n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

#[cfg(target_arch = "x86_64")]
mod scalar_asm {
    use super::Scalar;

    extern "C" {
        /// libsecp256k1 scalar_mul_512: l8 = a * b (512-bit product).
        /// SysV: rdi=l8, rsi=a, rdx=b.
        fn blvm_secp256k1_scalar_mul_512(l8: *mut u64, a: *const Scalar, b: *const Scalar);

        /// libsecp256k1 scalar_reduce_512: reduce 512-bit l mod n into r.
        /// Returns overflow for final reduction. SysV: rdi=r, rsi=l.
        fn blvm_secp256k1_scalar_reduce_512(r: *mut Scalar, l: *const u64) -> u64;
    }

    #[inline(always)]
    pub(super) unsafe fn scalar_mul_512_asm(l: *mut u64, a: *const Scalar, b: *const Scalar) {
        blvm_secp256k1_scalar_mul_512(l, a, b);
    }

    #[inline(always)]
    pub(super) unsafe fn scalar_reduce_512_asm(r: *mut Scalar, l: *const u64) -> u64 {
        blvm_secp256k1_scalar_reduce_512(r, l)
    }
}

use num_bigint::BigUint;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

/// Scalar modulo group order n. 4x64 limb layout (x86_64, aarch64).
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct Scalar {
    pub d: [u64; 4],
}

// secp256k1 group order n
const N_0: u64 = 0xBFD25E8CD0364141;
const N_1: u64 = 0xBAAEDCE6AF48A03B;
const N_2: u64 = 0xFFFFFFFFFFFFFFFE;
const N_3: u64 = 0xFFFFFFFFFFFFFFFF;

// 2^256 - n (for reduction)
const N_C_0: u64 = 0x402DA1732FC9BEBF;
const N_C_1: u64 = 0x4551231950B75FC4;
const N_C_2: u64 = 1;

// n/2 (for is_high)
const N_H_0: u64 = 0xDFE92F46681B20A0;
const N_H_1: u64 = 0x5D576E7357A4501D;
const N_H_2: u64 = 0xFFFFFFFFFFFFFFFF;
const N_H_3: u64 = 0x7FFFFFFFFFFFFFFF;

// modulus n (for safegcd inv)
#[allow(dead_code)]
const N: Scalar = Scalar {
    d: [N_0, N_1, N_2, N_3],
};

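/// Lambda for the GLV endomorphism: a cube root of unity mod n
/// (lambda^3 == 1 mod n), so any k splits as r1 + r2*lambda.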
const LAMBDA: Scalar = Scalar {
    d: [
        0xDF02967C1B23BD72,
        0x122E22EA20816678,
        0xA5261C028812645A,
        0x5363AD4CC05C30E0,
    ],
};

impl Scalar {
    pub fn zero() -> Self {
        Self { d: [0, 0, 0, 0] }
    }

    pub fn one() -> Self {
        Self { d: [1, 0, 0, 0] }
    }

    pub fn set_int(&mut self, v: u32) {
        self.d[0] = v as u64;
        self.d[1] = 0;
        self.d[2] = 0;
        self.d[3] = 0;
    }

    /// Set from 32-byte big-endian bytes, reducing mod n. Returns true if the input was >= n.
    pub fn set_b32(&mut self, bin: &[u8; 32]) -> bool {
        self.d[0] = read_be64(&bin[24..32]);
        self.d[1] = read_be64(&bin[16..24]);
        self.d[2] = read_be64(&bin[8..16]);
        self.d[3] = read_be64(&bin[0..8]);
        let overflow = self.check_overflow();
        self.reduce(overflow as u64);
        overflow
    }

    pub fn get_b32(&self, bin: &mut [u8; 32]) {
        write_be64(&mut bin[0..8], self.d[3]);
        write_be64(&mut bin[8..16], self.d[2]);
        write_be64(&mut bin[16..24], self.d[1]);
        write_be64(&mut bin[24..32], self.d[0]);
    }

    fn check_overflow(&self) -> bool {
        let mut yes = 0u64;
        let mut no = 0u64;
        no |= (self.d[3] < N_3) as u64;
        no |= (self.d[2] < N_2) as u64;
        yes |= (self.d[2] > N_2) as u64 & !no;
        no |= (self.d[1] < N_1) as u64;
        yes |= (self.d[1] > N_1) as u64 & !no;
        yes |= (self.d[0] >= N_0) as u64 & !no;
        yes != 0
    }

    fn reduce(&mut self, overflow: u64) {
        let mut t: u128 = self.d[0] as u128 + (overflow as u128 * N_C_0 as u128);
        self.d[0] = t as u64;
        t >>= 64;
        t += self.d[1] as u128 + (overflow as u128 * N_C_1 as u128);
        self.d[1] = t as u64;
        t >>= 64;
        t += self.d[2] as u128 + (overflow as u128 * N_C_2 as u128);
        self.d[2] = t as u64;
        t >>= 64;
        t += self.d[3] as u128;
        self.d[3] = t as u64;
    }

    pub fn is_zero(&self) -> bool {
        (self.d[0] | self.d[1] | self.d[2] | self.d[3]) == 0
    }

    pub fn is_one(&self) -> bool {
        ((self.d[0] ^ 1) | self.d[1] | self.d[2] | self.d[3]) == 0
    }

    /// True if scalar is odd (d[0] & 1).
    #[allow(dead_code)]
    fn is_odd(&self) -> bool {
        self.d[0] & 1 != 0
    }

    /// True if scalar is even. Used by wnaf_fixed (Pippenger).
    pub(crate) fn is_even(&self) -> bool {
        self.d[0] & 1 == 0
    }

    /// self = a - b (mod n). Result in [0, n-1].
    #[allow(dead_code)]
    fn sub(&mut self, a: &Scalar, b: &Scalar) {
        let mut neg_b = Scalar::zero();
        neg_b.negate(b);
        self.add(a, &neg_b);
    }

    /// self = self / 2. Only valid when self is even.
    #[allow(dead_code)]
    fn half(&mut self) {
        self.d[0] = (self.d[0] >> 1) | (self.d[1] << 63);
        self.d[1] = (self.d[1] >> 1) | (self.d[2] << 63);
        self.d[2] = (self.d[2] >> 1) | (self.d[3] << 63);
        self.d[3] >>= 1;
    }

    /// self = (self + n) / 2. Only valid when self is odd. Result in [0, n-1].
    #[allow(dead_code)]
    fn half_add_n(&mut self) {
        let mut t: u128 = self.d[0] as u128 + N_0 as u128;
        let c0 = t as u64;
        let mut c1 = (t >> 64) as u64;
        t = self.d[1] as u128 + N_1 as u128 + c1 as u128;
        c1 = t as u64;
        let mut c2 = (t >> 64) as u64;
        t = self.d[2] as u128 + N_2 as u128 + c2 as u128;
        c2 = t as u64;
        let mut c3 = (t >> 64) as u64;
        t = self.d[3] as u128 + N_3 as u128 + c3 as u128;
        c3 = t as u64;
        let c4 = (t >> 64) as u64;
        self.d[0] = (c0 >> 1) | (c1 << 63);
        self.d[1] = (c1 >> 1) | (c2 << 63);
        self.d[2] = (c2 >> 1) | (c3 << 63);
        self.d[3] = (c3 >> 1) | (c4 << 63);
        self.reduce(self.check_overflow() as u64);
    }

    /// div2 from safegcd: self/2 mod n when self is even, (self+n)/2 when odd.
    #[allow(dead_code)]
    fn div2(&mut self) {
        if self.is_odd() {
            self.half_add_n();
        } else {
            self.half();
        }
    }

    /// tmp = a + b (full 257-bit add, no reduction). Used when both are odd and we need (a+b)/2.
    #[allow(dead_code)]
    fn add_no_reduce(a: &Scalar, b: &Scalar) -> [u64; 5] {
        let mut t: u128 = a.d[0] as u128 + b.d[0] as u128;
        let c0 = t as u64;
        let mut c1 = (t >> 64) as u64;
        t = a.d[1] as u128 + b.d[1] as u128 + c1 as u128;
        c1 = t as u64;
        let mut c2 = (t >> 64) as u64;
        t = a.d[2] as u128 + b.d[2] as u128 + c2 as u128;
        c2 = t as u64;
        let mut c3 = (t >> 64) as u64;
        t = a.d[3] as u128 + b.d[3] as u128 + c3 as u128;
        c3 = t as u64;
        let c4 = (t >> 64) as u64;
        [c0, c1, c2, c3, c4]
    }

    /// self = (c0..c4) >> 1, then reduce mod n.
    #[allow(dead_code)]
    fn set_from_5limb_half(&mut self, c: &[u64; 5]) {
        self.d[0] = (c[0] >> 1) | (c[1] << 63);
        self.d[1] = (c[1] >> 1) | (c[2] << 63);
        self.d[2] = (c[2] >> 1) | (c[3] << 63);
        self.d[3] = (c[3] >> 1) | (c[4] << 63);
        self.reduce(self.check_overflow() as u64);
    }

    /// self = (a - b)/2 mod n, for odd a and b. When a >= b, a - b is even; when a < b,
    /// sub yields a - b + n (odd), and div2 adds n again before halving.
    #[allow(dead_code)]
    fn sub_half(&mut self, a: &Scalar, b: &Scalar) {
        self.sub(a, b);
        self.div2();
    }

    pub fn add(&mut self, a: &Scalar, b: &Scalar) -> bool {
        let mut t: u128 = a.d[0] as u128 + b.d[0] as u128;
        self.d[0] = t as u64;
        t >>= 64;
        t += a.d[1] as u128 + b.d[1] as u128;
        self.d[1] = t as u64;
        t >>= 64;
        t += a.d[2] as u128 + b.d[2] as u128;
        self.d[2] = t as u64;
        t >>= 64;
        t += a.d[3] as u128 + b.d[3] as u128;
        self.d[3] = t as u64;
        t >>= 64;
        let overflow = t as u64 + self.check_overflow() as u64;
        debug_assert!(overflow <= 1);
        self.reduce(overflow);
        overflow != 0
    }

    pub fn negate(&mut self, a: &Scalar) {
        let nonzero = if a.is_zero() { 0u64 } else { u64::MAX };
        let mut t: u128 = (!a.d[0]) as u128 + (N_0 + 1) as u128;
        self.d[0] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[1]) as u128 + N_1 as u128;
        self.d[1] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[2]) as u128 + N_2 as u128;
        self.d[2] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[3]) as u128 + N_3 as u128;
        self.d[3] = (t as u64) & nonzero;
    }

    pub fn mul(&mut self, a: &Scalar, b: &Scalar) {
        let mut l = [0u64; 8];
        scalar_mul_512(&mut l, a, b);
        scalar_reduce_512(self, &l);
    }

    /// split_lambda: find r1, r2 such that r1 + r2*lambda == k (mod n)
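    /// As in libsecp256k1: c1 = round(g1*k / 2^384) and c2 = round(g2*k / 2^384),
    /// then r2 = c1*(-b1) + c2*(-b2) and r1 = k - r2*lambda, so that |r1| and |r2|
    /// are at most ~128 bits (see test_split_lambda_ecdsa_scalar).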
    pub fn split_lambda(r1: &mut Scalar, r2: &mut Scalar, k: &Scalar) {
        const MINUS_B1: Scalar = Scalar {
            d: [
                (0x6F547FA9u64 << 32) | 0x0ABFE4C3,
                (0xE4437ED6u64 << 32) | 0x010E8828,
                0,
                0,
            ],
        };
        const MINUS_B2: Scalar = Scalar {
            d: [
                (0xD765CDA8u64 << 32) | 0x3DB1562C,
                (0x8A280AC5u64 << 32) | 0x0774346D,
                (0xFFFFFFFFu64 << 32) | 0xFFFFFFFE,
                (0xFFFFFFFFu64 << 32) | 0xFFFFFFFF,
            ],
        };
        const G1: Scalar = Scalar {
            d: [
                (0xE893209Au64 << 32) | 0x45DBB031,
                (0x3DAA8A14u64 << 32) | 0x71E8CA7F,
                (0xE86C90E4u64 << 32) | 0x9284EB15,
                (0x3086D221u64 << 32) | 0xA7D46BCD,
            ],
        };
        const G2: Scalar = Scalar {
            d: [
                (0x1571B4AEu64 << 32) | 0x8AC47F71,
                (0x221208ACu64 << 32) | 0x9DF506C6,
                (0x6F547FA9u64 << 32) | 0x0ABFE4C4,
                (0xE4437ED6u64 << 32) | 0x010E8828,
            ],
        };

        let mut c1 = Scalar::zero();
        let mut c2 = Scalar::zero();
        scalar_mul_shift_var(&mut c1, k, &G1, 384);
        scalar_mul_shift_var(&mut c2, k, &G2, 384);
        let mut t = Scalar::zero();
        t.mul(&c1, &MINUS_B1);
        c1 = t;
        t.mul(&c2, &MINUS_B2);
        c2 = t;
        r2.add(&c1, &c2);
        r1.mul(r2, &LAMBDA);
        let mut neg = Scalar::zero();
        neg.negate(r1);
        r1.add(&neg, k);
    }

    /// Extract `count` bits at `offset` (0..256). Count in [1,32]; all bits must lie
    /// in a single 64-bit limb (offset and offset+count-1 in the same u64).
    pub fn get_bits_limb32(&self, offset: u32, count: u32) -> u32 {
        debug_assert!(count > 0 && count <= 32);
        debug_assert!((offset + count - 1) >> 6 == offset >> 6);
        let limb = offset >> 6;
        let shift = offset & 0x3F;
        let mask = if count == 32 {
            u32::MAX
        } else {
            (1u32 << count) - 1
        };
        ((self.d[limb as usize] >> shift) as u32) & mask
    }

    /// Extract `count` bits at `offset`. Count in [1,32], offset+count <= 256.
    pub fn get_bits_var(&self, offset: u32, count: u32) -> u32 {
        debug_assert!(count > 0 && count <= 32);
        debug_assert!(offset + count <= 256);
        if (offset + count - 1) >> 6 == offset >> 6 {
            self.get_bits_limb32(offset, count)
        } else {
            let limb = (offset >> 6) as usize;
            let shift = offset & 0x3F;
            let mask = if count == 32 {
                u32::MAX
            } else {
                (1u32 << count) - 1
            };
            let lo = self.d[limb] >> shift;
            let hi = self.d[limb + 1].wrapping_shl(64u32 - shift);
            ((lo | hi) as u32) & mask
        }
    }

    pub fn split_128(r1: &mut Scalar, r2: &mut Scalar, k: &Scalar) {
        r1.d[0] = k.d[0];
        r1.d[1] = k.d[1];
        r1.d[2] = 0;
        r1.d[3] = 0;
        r2.d[0] = k.d[2];
        r2.d[1] = k.d[3];
        r2.d[2] = 0;
        r2.d[3] = 0;
    }

    /// Modular inverse. Variable-time. r = a^(-1) mod n. If a is zero, r is zero.
    /// On x86_64/aarch64: modinv64 (safegcd). Else: Fermat via num-bigint modpow.
    pub fn inv_var(&mut self, a: &Scalar) {
        if a.is_zero() {
            *self = Scalar::zero();
            return;
        }
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            use crate::modinv64::{modinv64, SECP256K1_SCALAR_MODINV_MODINFO};
            let mut x = scalar_to_signed62(a);
            modinv64(&mut x, &SECP256K1_SCALAR_MODINV_MODINFO);
            scalar_from_signed62(self, &x);
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let a_big = scalar_to_biguint(a);
            let n_big = scalar_to_biguint(&N);
            let exp = &n_big - 2u32;
            let inv_big = a_big.modpow(&exp, &n_big);
            biguint_to_scalar(self, &inv_big);
        }
    }

    /// True if scalar is in the upper half [n/2, n).
    pub fn is_high(&self) -> bool {
        let mut yes = 0u64;
        let mut no = 0u64;
        no |= (self.d[3] < N_H_3) as u64;
        yes |= (self.d[3] > N_H_3) as u64 & !no;
        no |= (self.d[2] < N_H_2) as u64 & !yes;
        no |= (self.d[1] < N_H_1) as u64 & !yes;
        yes |= (self.d[1] > N_H_1) as u64 & !no;
        yes |= (self.d[0] > N_H_0) as u64 & !no;
        yes != 0
    }

    /// Conditionally negate: if flag != 0, negate in place. Returns 1 if negated, -1 if not.
    pub fn cond_negate(&mut self, flag: i32) -> i32 {
        let mask = if flag != 0 { u64::MAX } else { 0 };
        let nonzero = if self.is_zero() { 0 } else { u64::MAX };
        let mut t: u128 = (self.d[0] ^ mask) as u128;
        t += ((N_0 + 1) & mask) as u128;
        self.d[0] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[1] ^ mask) as u128;
        t += (N_1 & mask) as u128;
        self.d[1] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[2] ^ mask) as u128;
        t += (N_2 & mask) as u128;
        self.d[2] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[3] ^ mask) as u128;
        t += (N_3 & mask) as u128;
        self.d[3] = (t as u64) & nonzero;
        if mask == 0 {
            -1
        } else {
            1
        }
    }
}
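
// Added illustrative check (not from the upstream port): negate/add round-trip,
// plus the cond_negate return convention documented above.
#[cfg(test)]
#[test]
fn test_negate_add_round_trip() {
    let mut a = Scalar::zero();
    a.set_int(12345);
    let mut neg = Scalar::zero();
    neg.negate(&a);
    let mut sum = Scalar::zero();
    sum.add(&a, &neg);
    assert!(sum.is_zero(), "a + (-a) mod n = 0");

    // flag != 0 negates in place and returns 1; flag == 0 leaves it and returns -1.
    let mut b = a;
    assert_eq!(b.cond_negate(1), 1);
    assert!(bool::from(b.ct_eq(&neg)));
    let mut c = a;
    assert_eq!(c.cond_negate(0), -1);
    assert!(bool::from(c.ct_eq(&a)));
}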

impl ConstantTimeEq for Scalar {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.d[0].ct_eq(&other.d[0])
            & self.d[1].ct_eq(&other.d[1])
            & self.d[2].ct_eq(&other.d[2])
            & self.d[3].ct_eq(&other.d[3])
    }
}

impl ConditionallySelectable for Scalar {
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        Self {
            d: [
                u64::conditional_select(&a.d[0], &b.d[0], choice),
                u64::conditional_select(&a.d[1], &b.d[1], choice),
                u64::conditional_select(&a.d[2], &b.d[2], choice),
                u64::conditional_select(&a.d[3], &b.d[3], choice),
            ],
        }
    }
}

/// Pack 4×64 scalar limbs into 5×62 signed limbs for modinv64.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn scalar_to_signed62(a: &Scalar) -> crate::modinv64::Signed62 {
    const M62: u64 = u64::MAX >> 2;
    let d = &a.d;
    crate::modinv64::Signed62 {
        v: [
            (d[0] & M62) as i64,
            ((d[0] >> 62 | d[1] << 2) & M62) as i64,
            ((d[1] >> 60 | d[2] << 4) & M62) as i64,
            ((d[2] >> 58 | d[3] << 6) & M62) as i64,
            (d[3] >> 56) as i64,
        ],
    }
}

/// Unpack 5×62 signed limbs back to 4×64 scalar limbs.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn scalar_from_signed62(r: &mut Scalar, a: &crate::modinv64::Signed62) {
    let v = &a.v;
    r.d[0] = (v[0] as u64) | ((v[1] as u64) << 62);
    r.d[1] = ((v[1] as u64) >> 2) | ((v[2] as u64) << 60);
    r.d[2] = ((v[2] as u64) >> 4) | ((v[3] as u64) << 58);
    r.d[3] = ((v[3] as u64) >> 6) | ((v[4] as u64) << 56);
}
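
// Added round-trip sketch for the 4x64 <-> 5x62 repacking above, assuming the
// Signed62 layout from crate::modinv64. 4*64 = 256 bits fit in 4*62 + 8 bits,
// so the conversion must be lossless.
#[cfg(all(test, any(target_arch = "x86_64", target_arch = "aarch64")))]
#[test]
fn test_signed62_round_trip() {
    let a = Scalar {
        d: [
            0x0123456789ABCDEF,
            0xFEDCBA9876543210,
            0x0F1E2D3C4B5A6978,
            0x1122334455667788,
        ],
    };
    let s = scalar_to_signed62(&a);
    let mut r = Scalar::zero();
    scalar_from_signed62(&mut r, &s);
    assert_eq!(r.d, a.d, "signed62 round-trip should be lossless");
}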

#[allow(dead_code)]
fn scalar_to_biguint(s: &Scalar) -> BigUint {
    let mut bytes = [0u8; 32];
    s.get_b32(&mut bytes);
    BigUint::from_bytes_be(&bytes)
}

#[allow(dead_code)]
fn biguint_to_scalar(r: &mut Scalar, b: &BigUint) {
    let bytes = b.to_bytes_be();
    let mut buf = [0u8; 32];
    let len = bytes.len().min(32);
    let start = 32 - len;
    buf[start..].copy_from_slice(&bytes[..len]);
    r.set_b32(&buf);
}

fn read_be64(b: &[u8]) -> u64 {
    u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])
}

fn write_be64(b: &mut [u8], v: u64) {
    b[0..8].copy_from_slice(&v.to_be_bytes());
}

fn scalar_mul_512(l: &mut [u64; 8], a: &Scalar, b: &Scalar) {
    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            scalar_asm::scalar_mul_512_asm(l.as_mut_ptr(), a, b);
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_mul_512_rust(l, a, b);
    }
}

#[cfg(not(target_arch = "x86_64"))]
fn scalar_mul_512_rust(l: &mut [u64; 8], a: &Scalar, b: &Scalar) {
    let mut c0: u64 = 0;
    let mut c1: u64 = 0;
    let mut c2: u32 = 0;

    macro_rules! muladd_fast {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let prod_lo = prod as u64;
            let prod_hi = (prod >> 64) as u64;
            let (lo, o) = c0.overflowing_add(prod_lo);
            c0 = lo;
            c1 += prod_hi + o as u64; // C: th = prod_hi + (c0 < tl)
        }};
    }
    macro_rules! muladd {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let hi = (prod >> 64) as u64;
            let (lo, o1) = c0.overflowing_add(prod as u64);
            c0 = lo;
            let th = hi + o1 as u64;
            let (mid, o2) = c1.overflowing_add(th);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    #[allow(unused_macros)] // parity with scalar_reduce_512_limbs; unused here
    macro_rules! sumadd {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            let (mid, o2) = c1.overflowing_add(o as u64);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! extract {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = c2 as u64;
            c2 = 0;
            n
        }};
    }
    macro_rules! extract_fast {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = 0;
            n
        }};
    }

    muladd_fast!(a.d[0], b.d[0]);
    l[0] = extract_fast!();
    muladd!(a.d[0], b.d[1]);
    muladd!(a.d[1], b.d[0]);
    l[1] = extract!();
    muladd!(a.d[0], b.d[2]);
    muladd!(a.d[1], b.d[1]);
    muladd!(a.d[2], b.d[0]);
    l[2] = extract!();
    muladd!(a.d[0], b.d[3]);
    muladd!(a.d[1], b.d[2]);
    muladd!(a.d[2], b.d[1]);
    muladd!(a.d[3], b.d[0]);
    l[3] = extract!();
    muladd!(a.d[1], b.d[3]);
    muladd!(a.d[2], b.d[2]);
    muladd!(a.d[3], b.d[1]);
    l[4] = extract!();
    muladd!(a.d[2], b.d[3]);
    muladd!(a.d[3], b.d[2]);
    l[5] = extract!();
    muladd_fast!(a.d[3], b.d[3]);
    l[6] = extract_fast!();
    l[7] = c0;
}

#[allow(dead_code)]
fn limbs_512_to_biguint(l: &[u64; 8]) -> BigUint {
    let mut acc = BigUint::from(0u64);
    for (i, &limb) in l.iter().enumerate() {
        acc += BigUint::from(limb) << (64 * i);
    }
    acc
}
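
// Added cross-check sketch (not from the upstream port): the 512-bit schoolbook
// product from scalar_mul_512 should match num-bigint multiplication, using the
// limbs_512_to_biguint helper above.
#[cfg(test)]
#[test]
fn test_scalar_mul_512_matches_biguint() {
    let a = Scalar {
        d: [u64::MAX, 1, 0xDEADBEEF, 42],
    };
    let b = Scalar {
        d: [7, u64::MAX, 3, 0x12345678],
    };
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, &a, &b);
    let got = limbs_512_to_biguint(&l);
    let want = scalar_to_biguint(&a) * scalar_to_biguint(&b);
    assert_eq!(got, want, "512-bit product should match BigUint");
}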

/// Limb-based 512→256 reduction mod n; replaces BigUint on the hot path.
/// Port of the libsecp256k1 scalar_reduce_512 C fallback (muladd/extract).
#[cfg(not(target_arch = "x86_64"))]
fn scalar_reduce_512_limbs(r: &mut Scalar, l: &[u64; 8]) {
    let n0 = l[4];
    let n1 = l[5];
    let n2 = l[6];
    let n3 = l[7];

    let mut c0: u64 = l[0];
    let mut c1: u64 = 0;
    let mut c2: u32 = 0;

    macro_rules! muladd_fast {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let (lo, o) = c0.overflowing_add(prod as u64);
            c0 = lo;
            c1 += (prod >> 64) as u64 + o as u64;
        }};
    }
    macro_rules! muladd {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let (lo, o1) = c0.overflowing_add(prod as u64);
            c0 = lo;
            let th = (prod >> 64) as u64 + o1 as u64;
            let (mid, o2) = c1.overflowing_add(th);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! sumadd_fast {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            c1 += o as u64;
        }};
    }
    macro_rules! sumadd {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            let (mid, o2) = c1.overflowing_add(o as u64);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! extract {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = c2 as u64;
            c2 = 0;
            n
        }};
    }
    macro_rules! extract_fast {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = 0;
            n
        }};
    }

    // Reduce 512 bits into 385: m[0..6] = l[0..3] + n[0..3] * N_C
    muladd_fast!(n0, N_C_0);
    let m0 = extract_fast!();
    sumadd_fast!(l[1]);
    muladd!(n1, N_C_0);
    muladd!(n0, N_C_1);
    let m1 = extract!();
    sumadd!(l[2]);
    muladd!(n2, N_C_0);
    muladd!(n1, N_C_1);
    sumadd!(n0);
    let m2 = extract!();
    sumadd!(l[3]);
    muladd!(n3, N_C_0);
    muladd!(n2, N_C_1);
    sumadd!(n1);
    let m3 = extract!();
    muladd!(n3, N_C_1);
    sumadd!(n2);
    let m4 = extract!();
    sumadd_fast!(n3);
    let m5 = extract_fast!();
    let m6 = c0 as u32;

    // Reduce 385 into 258: p[0..4] = m[0..3] + m[4..6] * N_C
    c0 = m0;
    c1 = 0;
    c2 = 0;
    muladd_fast!(m4, N_C_0);
    let p0 = extract_fast!();
    sumadd_fast!(m1);
    muladd!(m5, N_C_0);
    muladd!(m4, N_C_1);
    let p1 = extract!();
    sumadd!(m2);
    muladd!(m6 as u64, N_C_0);
    muladd!(m5, N_C_1);
    sumadd!(m4);
    let p2 = extract!();
    sumadd_fast!(m3);
    muladd_fast!(m6 as u64, N_C_1);
    sumadd_fast!(m5);
    let p3 = extract_fast!();
    let p4 = (c0 + m6 as u64) as u32;

    // Reduce 258 into 256: r = p[0..3] + p4 * N_C
    let mut t: u128 = p0 as u128;
    t += (N_C_0 as u128) * (p4 as u128);
    r.d[0] = t as u64;
    t >>= 64;
    t += p1 as u128;
    t += (N_C_1 as u128) * (p4 as u128);
    r.d[1] = t as u64;
    t >>= 64;
    t += p2 as u128;
    t += p4 as u128;
    r.d[2] = t as u64;
    t >>= 64;
    t += p3 as u128;
    r.d[3] = t as u64;
    let c = (t >> 64) as u64;

    // Final reduction
    scalar_reduce(r, c + scalar_check_overflow(r));
}

/// Add overflow*N_C to r. overflow is 0 or 1.
fn scalar_reduce(r: &mut Scalar, overflow: u64) {
    let of = overflow as u128;
    let mut t: u128 = r.d[0] as u128;
    t += of * (N_C_0 as u128);
    r.d[0] = t as u64;
    t >>= 64;
    t += r.d[1] as u128;
    t += of * (N_C_1 as u128);
    r.d[1] = t as u64;
    t >>= 64;
    t += r.d[2] as u128;
    t += of * (N_C_2 as u128);
    r.d[2] = t as u64;
    t >>= 64;
    r.d[3] = (t as u64).wrapping_add(r.d[3]);
}

/// Returns 1 if r >= n, else 0.
fn scalar_check_overflow(r: &Scalar) -> u64 {
    let mut yes = 0u64;
    let mut no = 0u64;
    no |= (r.d[3] < N_3) as u64;
    no |= (r.d[2] < N_2) as u64;
    yes |= (r.d[2] > N_2) as u64 & !no;
    no |= (r.d[1] < N_1) as u64;
    yes |= (r.d[1] > N_1) as u64 & !no;
    yes |= (r.d[0] >= N_0) as u64 & !no;
    yes
}

fn scalar_reduce_512(r: &mut Scalar, l: &[u64; 8]) {
    #[cfg(target_arch = "x86_64")]
    {
        let c = unsafe { scalar_asm::scalar_reduce_512_asm(r, l.as_ptr()) };
        scalar_reduce(r, c + scalar_check_overflow(r));
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_reduce_512_limbs(r, l);
    }
}

#[cfg(test)]
#[test]
fn test_scalar_reduce_n_plus_1() {
    let l = [N_0 + 1, N_1, N_2, N_3, 0, 0, 0, 0];
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_one(), "(n+1) mod n = 1, got r.d = {:?}", r.d);
}

#[cfg(test)]
#[test]
fn test_scalar_mul_inv2_times_2() {
    let inv2_hex = "7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1";
    let inv2_bytes = hex::decode(inv2_hex).unwrap();
    let mut buf = [0u8; 32];
    buf.copy_from_slice(&inv2_bytes);
    let mut inv2 = Scalar::zero();
    inv2.set_b32(&buf);
    let mut two = Scalar::zero();
    two.set_int(2);
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, &inv2, &two);
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_one(), "inv2*2 mod n = 1");
}
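
// Added companion to the n+1 test above: n itself reduces to zero mod n.
#[cfg(test)]
#[test]
fn test_scalar_reduce_n_is_zero() {
    let l = [N_0, N_1, N_2, N_3, 0, 0, 0, 0];
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_zero(), "n mod n = 0, got r.d = {:?}", r.d);
}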
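/// r = round(a*b / 2^shift) for shift >= 256, without reducing the product mod n
/// first. Variable time. The final cadd_bit adds bit shift-1 of the product,
/// i.e. rounds to nearest.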
fn scalar_mul_shift_var(r: &mut Scalar, a: &Scalar, b: &Scalar, shift: u32) {
    assert!(shift >= 256);
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, a, b);
    let shiftlimbs = (shift >> 6) as usize;
    let shiftlow = shift & 0x3F;
    let shifthigh = 64 - shiftlow;
    r.d[0] = if shift < 512 {
        (l[shiftlimbs] >> shiftlow)
            | (if shift < 448 && shiftlow != 0 {
                l[1 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[1] = if shift < 448 {
        (l[1 + shiftlimbs] >> shiftlow)
            | (if shift < 384 && shiftlow != 0 {
                l[2 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[2] = if shift < 384 {
        (l[2 + shiftlimbs] >> shiftlow)
            | (if shift < 320 && shiftlow != 0 {
                l[3 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[3] = if shift < 320 {
        l[3 + shiftlimbs] >> shiftlow
    } else {
        0
    };
    let bit = (l[(shift - 1) as usize >> 6] >> ((shift - 1) & 0x3F)) & 1;
    scalar_cadd_bit(r, 0, bit != 0);
}
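/// Conditionally add 2^bit to r, without modular reduction (the caller ensures
/// the sum cannot overflow 256 bits). When flag is false this is a no-op.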
fn scalar_cadd_bit(r: &mut Scalar, bit: u32, flag: bool) {
    let bit = if flag { bit } else { bit + 256 };
    if bit >= 256 {
        return;
    }
    let mut t: u128 = r.d[0] as u128
        + if (bit >> 6) == 0 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[0] = t as u64;
    t >>= 64;
    t += r.d[1] as u128
        + if (bit >> 6) == 1 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[1] = t as u64;
    t >>= 64;
    t += r.d[2] as u128
        + if (bit >> 6) == 2 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[2] = t as u64;
    t >>= 64;
    t += r.d[3] as u128
        + if (bit >> 6) == 3 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[3] = t as u64;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_lambda_identity() {
        // r1 + lambda*r2 == k (mod n)
        let mut k = Scalar::zero();
        k.set_int(42);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(bool::from(check.ct_eq(&k)), "r1 + lambda*r2 should equal k");
    }
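
    // Added illustrative check: lambda is a cube root of unity mod n, the
    // property the split_lambda decomposition relies on.
    #[test]
    fn test_lambda_cubed_is_one() {
        let mut l2 = Scalar::zero();
        l2.mul(&LAMBDA, &LAMBDA);
        let mut l3 = Scalar::zero();
        l3.mul(&l2, &LAMBDA);
        assert!(l3.is_one(), "lambda^3 mod n = 1");
    }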

    #[test]
    fn test_split_lambda_neg_three() {
        let mut three = Scalar::zero();
        three.set_int(3);
        let mut k = Scalar::zero();
        k.negate(&three); // k = -3 mod n

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(
            bool::from(check.ct_eq(&k)),
            "r1 + lambda*r2 should equal k for k=-3"
        );
    }

    #[test]
    fn test_split_lambda_ecdsa_scalar() {
        let mut k = Scalar::zero();
        k.d = [
            11125243483441707226,
            2149109665766520832,
            14302025600096445326,
            4162584031737161978,
        ];

        let n_big = scalar_to_biguint(&N);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let r1_big = scalar_to_biguint(&r1);
        let r2_big = scalar_to_biguint(&r2);
        let n_half = &n_big / BigUint::from(2u64);
        let r1_abs = if r1_big > n_half {
            &n_big - &r1_big
        } else {
            r1_big.clone()
        };
        let r2_abs = if r2_big > n_half {
            &n_big - &r2_big
        } else {
            r2_big.clone()
        };
        assert!(
            r1_abs.bits() <= 128,
            "|r1| exceeds 128 bits: {}",
            r1_abs.bits()
        );
        assert!(
            r2_abs.bits() <= 128,
            "|r2| exceeds 128 bits: {}",
            r2_abs.bits()
        );

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(bool::from(check.ct_eq(&k)), "r1 + lambda*r2 should equal k");
    }
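
    // Added boundary check (illustrative): floor(n/2) is not "high", while
    // floor(n/2) + 1 is, matching the N_H constants above.
    #[test]
    fn test_is_high_boundary() {
        let half = Scalar {
            d: [N_H_0, N_H_1, N_H_2, N_H_3],
        };
        assert!(!half.is_high());
        let mut one = Scalar::zero();
        one.set_int(1);
        let mut half_plus_one = Scalar::zero();
        half_plus_one.add(&half, &one);
        assert!(half_plus_one.is_high());
    }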

    #[test]
    fn test_split_128_identity() {
        // r1 + 2^128*r2 == k
        let mut k = Scalar::zero();
        k.set_int(0x1234_5678);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_128(&mut r1, &mut r2, &k);

        let mut two_128 = Scalar::zero();
        two_128.d[2] = 1;
        let mut r2_shifted = Scalar::zero();
        r2_shifted.mul(&r2, &two_128);
        let mut check = Scalar::zero();
        check.add(&r1, &r2_shifted);
        assert!(bool::from(check.ct_eq(&k)), "r1 + 2^128*r2 should equal k");
    }
}