
blvm_secp256k1/scalar.rs

//! Scalar arithmetic modulo the secp256k1 group order n.
//!
//! n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
//!
//! **Timing:** [`Scalar::inv`] is constant-time on `x86_64` and `aarch64`, where it uses the
//! safegcd-based `modinv64` path. On other targets it falls back to a Fermat inverse via
//! `BigUint::modpow`, which is not constant-time. [`Scalar::inv_var`] is an alias for
//! [`inv`](Scalar::inv).

#[cfg(target_arch = "x86_64")]
mod scalar_asm {
    use super::Scalar;

    extern "C" {
        /// libsecp256k1 scalar_mul_512: l8 = a * b (512-bit product).
        /// SysV: rdi=l8, rsi=a, rdx=b.
        fn blvm_secp256k1_scalar_mul_512(l8: *mut u64, a: *const Scalar, b: *const Scalar);

        /// libsecp256k1 scalar_reduce_512: reduce 512-bit l mod n into r.
        /// Returns overflow for final reduction. SysV: rdi=r, rsi=l.
        fn blvm_secp256k1_scalar_reduce_512(r: *mut Scalar, l: *const u64) -> u64;
    }

    #[inline(always)]
    pub(super) unsafe fn scalar_mul_512_asm(l: *mut u64, a: *const Scalar, b: *const Scalar) {
        blvm_secp256k1_scalar_mul_512(l, a, b);
    }

    #[inline(always)]
    pub(super) unsafe fn scalar_reduce_512_asm(r: *mut Scalar, l: *const u64) -> u64 {
        blvm_secp256k1_scalar_reduce_512(r, l)
    }
}

use num_bigint::BigUint;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

/// Scalar modulo group order n. 4x64 limb layout (x86_64, aarch64).
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct Scalar {
    pub d: [u64; 4],
}

// secp256k1 group order n
const N_0: u64 = 0xBFD25E8CD0364141;
const N_1: u64 = 0xBAAEDCE6AF48A03B;
const N_2: u64 = 0xFFFFFFFFFFFFFFFE;
const N_3: u64 = 0xFFFFFFFFFFFFFFFF;

// 2^256 - n (for reduction)
const N_C_0: u64 = 0x402DA1732FC9BEBF;
const N_C_1: u64 = 0x4551231950B75FC4;
const N_C_2: u64 = 1;

// n/2 (for is_high)
const N_H_0: u64 = 0xDFE92F46681B20A0;
const N_H_1: u64 = 0x5D576E7357A4501D;
const N_H_2: u64 = 0xFFFFFFFFFFFFFFFF;
const N_H_3: u64 = 0x7FFFFFFFFFFFFFFF;

// group order n as a Scalar (used by the Fermat-fallback inverse and by tests)
#[allow(dead_code)]
const N: Scalar = Scalar {
    d: [N_0, N_1, N_2, N_3],
};

const LAMBDA: Scalar = Scalar {
    d: [
        0xDF02967C1B23BD72,
        0x122E22EA20816678,
        0xA5261C028812645A,
        0x5363AD4CC05C30E0,
    ],
};

impl Scalar {
    pub fn zero() -> Self {
        Self { d: [0, 0, 0, 0] }
    }

    pub fn one() -> Self {
        Self { d: [1, 0, 0, 0] }
    }

    pub fn set_int(&mut self, v: u32) {
        self.d[0] = v as u64;
        self.d[1] = 0;
        self.d[2] = 0;
        self.d[3] = 0;
    }

    /// Set from 32-byte big-endian. Reduces mod n.
    /// Returns true if the input was >= n (i.e. a reduction took place).
    pub fn set_b32(&mut self, bin: &[u8; 32]) -> bool {
        self.d[0] = read_be64(&bin[24..32]);
        self.d[1] = read_be64(&bin[16..24]);
        self.d[2] = read_be64(&bin[8..16]);
        self.d[3] = read_be64(&bin[0..8]);
        let overflow = self.check_overflow();
        self.reduce(overflow as u64);
        overflow
    }

    pub fn get_b32(&self, bin: &mut [u8; 32]) {
        write_be64(&mut bin[0..8], self.d[3]);
        write_be64(&mut bin[8..16], self.d[2]);
        write_be64(&mut bin[16..24], self.d[1]);
        write_be64(&mut bin[24..32], self.d[0]);
    }

    fn check_overflow(&self) -> bool {
        let mut yes = 0u64;
        let mut no = 0u64;
        no |= (self.d[3] < N_3) as u64;
        no |= (self.d[2] < N_2) as u64;
        yes |= (self.d[2] > N_2) as u64 & !no;
        no |= (self.d[1] < N_1) as u64;
        yes |= (self.d[1] > N_1) as u64 & !no;
        yes |= (self.d[0] >= N_0) as u64 & !no;
        yes != 0
    }

    fn reduce(&mut self, overflow: u64) {
        let mut t: u128 = self.d[0] as u128 + (overflow as u128 * N_C_0 as u128);
        self.d[0] = t as u64;
        t >>= 64;
        t += self.d[1] as u128 + (overflow as u128 * N_C_1 as u128);
        self.d[1] = t as u64;
        t >>= 64;
        t += self.d[2] as u128 + (overflow as u128 * N_C_2 as u128);
        self.d[2] = t as u64;
        t >>= 64;
        t += self.d[3] as u128;
        self.d[3] = t as u64;
    }

    pub fn is_zero(&self) -> bool {
        (self.d[0] | self.d[1] | self.d[2] | self.d[3]) == 0
    }

    pub fn is_one(&self) -> bool {
        ((self.d[0] ^ 1) | self.d[1] | self.d[2] | self.d[3]) == 0
    }

    /// True if scalar is odd (d[0] & 1).
    pub fn is_odd(&self) -> bool {
        self.d[0] & 1 != 0
    }

    /// True if scalar is even. Used by wnaf_fixed (Pippenger).
    pub(crate) fn is_even(&self) -> bool {
        self.d[0] & 1 == 0
    }

    /// self = a - b (mod n). Result in [0, n-1].
    #[allow(dead_code)]
    fn sub(&mut self, a: &Scalar, b: &Scalar) {
        let mut neg_b = Scalar::zero();
        neg_b.negate(b);
        self.add(a, &neg_b);
    }

    /// self = self / 2. Only valid when self is even.
    #[allow(dead_code)]
    fn half(&mut self) {
        self.d[0] = (self.d[0] >> 1) | (self.d[1] << 63);
        self.d[1] = (self.d[1] >> 1) | (self.d[2] << 63);
        self.d[2] = (self.d[2] >> 1) | (self.d[3] << 63);
        self.d[3] >>= 1;
    }

    /// self = (self + n) / 2. Only valid when self is odd. Result in [0, n-1].
    #[allow(dead_code)]
    fn half_add_n(&mut self) {
        let mut t: u128 = self.d[0] as u128 + N_0 as u128;
        let c0 = t as u64;
        let mut c1 = (t >> 64) as u64;
        t = self.d[1] as u128 + N_1 as u128 + c1 as u128;
        c1 = t as u64;
        let mut c2 = (t >> 64) as u64;
        t = self.d[2] as u128 + N_2 as u128 + c2 as u128;
        c2 = t as u64;
        let mut c3 = (t >> 64) as u64;
        t = self.d[3] as u128 + N_3 as u128 + c3 as u128;
        c3 = t as u64;
        let c4 = (t >> 64) as u64;
        self.d[0] = (c0 >> 1) | (c1 << 63);
        self.d[1] = (c1 >> 1) | (c2 << 63);
        self.d[2] = (c2 >> 1) | (c3 << 63);
        self.d[3] = (c3 >> 1) | (c4 << 63);
        self.reduce(self.check_overflow() as u64);
    }

    /// div2(M, x): x/2 mod n when x even, (x+n)/2 mod n when x odd.
    ///
    /// Constant-time: no branch on the scalar's parity bit. Builds a mask from
    /// the low bit and adds `N * mask` before shifting, so both halves of the
    /// original `if/else` execute in the same instruction stream.
    pub fn div2(&mut self) {
        // add_mask = u64::MAX when odd, 0 when even — no branch.
        let add_mask = 0u64.wrapping_sub(self.d[0] & 1);
        let mut t: u128 = self.d[0] as u128 + (N_0 & add_mask) as u128;
        let c0 = t as u64;
        t >>= 64;
        t += self.d[1] as u128 + (N_1 & add_mask) as u128;
        let c1 = t as u64;
        t >>= 64;
        t += self.d[2] as u128 + (N_2 & add_mask) as u128;
        let c2 = t as u64;
        t >>= 64;
        t += self.d[3] as u128 + (N_3 & add_mask) as u128;
        let c3 = t as u64;
        let c4 = (t >> 64) as u64;
        self.d[0] = (c0 >> 1) | (c1 << 63);
        self.d[1] = (c1 >> 1) | (c2 << 63);
        self.d[2] = (c2 >> 1) | (c3 << 63);
        self.d[3] = (c3 >> 1) | (c4 << 63);
        // For any valid scalar in [0, n-1] the result is already < n; reduce is a no-op
        // (overflow = 0) but kept as a safety net — reduce(0) leaves the value unchanged.
        self.reduce(self.check_overflow() as u64);
    }

    /// self = a/2 (mod n). Same as libsecp256k1 `scalar_half` (used by `ecmult_const`).
    pub fn half_modn(&mut self, a: &Scalar) {
        *self = *a;
        self.div2();
    }

    /// tmp = a + b (full 257-bit add, no reduction). Used when both are odd and we need (a+b)/2.
    #[allow(dead_code)]
    fn add_no_reduce(a: &Scalar, b: &Scalar) -> [u64; 5] {
        let mut t: u128 = a.d[0] as u128 + b.d[0] as u128;
        let c0 = t as u64;
        let mut c1 = (t >> 64) as u64;
        t = a.d[1] as u128 + b.d[1] as u128 + c1 as u128;
        c1 = t as u64;
        let mut c2 = (t >> 64) as u64;
        t = a.d[2] as u128 + b.d[2] as u128 + c2 as u128;
        c2 = t as u64;
        let mut c3 = (t >> 64) as u64;
        t = a.d[3] as u128 + b.d[3] as u128 + c3 as u128;
        c3 = t as u64;
        let c4 = (t >> 64) as u64;
        [c0, c1, c2, c3, c4]
    }

    /// self = (c0..c4) >> 1, then reduce mod n.
    #[allow(dead_code)]
    fn set_from_5limb_half(&mut self, c: &[u64; 5]) {
        self.d[0] = (c[0] >> 1) | (c[1] << 63);
        self.d[1] = (c[1] >> 1) | (c[2] << 63);
        self.d[2] = (c[2] >> 1) | (c[3] << 63);
        self.d[3] = (c[3] >> 1) | (c[4] << 63);
        self.reduce(self.check_overflow() as u64);
    }

    /// self = (a - b) >> 1 mod n. a and b odd. When a>=b, a-b is even; when a<b, a-b+n is odd, div2 adds n.
    #[allow(dead_code)]
    fn sub_half(&mut self, a: &Scalar, b: &Scalar) {
        self.sub(a, b);
        self.div2();
    }

    pub fn add(&mut self, a: &Scalar, b: &Scalar) -> bool {
        let mut t: u128 = a.d[0] as u128 + b.d[0] as u128;
        self.d[0] = t as u64;
        t >>= 64;
        t += a.d[1] as u128 + b.d[1] as u128;
        self.d[1] = t as u64;
        t >>= 64;
        t += a.d[2] as u128 + b.d[2] as u128;
        self.d[2] = t as u64;
        t >>= 64;
        t += a.d[3] as u128 + b.d[3] as u128;
        self.d[3] = t as u64;
        t >>= 64;
        let overflow = t as u64 + self.check_overflow() as u64;
        debug_assert!(overflow <= 1);
        self.reduce(overflow);
        overflow != 0
    }

    pub fn negate(&mut self, a: &Scalar) {
        // Branchless mask: u64::MAX iff a != 0 (negate is a no-op on zero in mod-n sense).
        let nz = a.d[0] | a.d[1] | a.d[2] | a.d[3];
        let nonzero = 0u64.wrapping_sub((nz != 0) as u64);
        let mut t: u128 = (!a.d[0]) as u128 + (N_0 + 1) as u128;
        self.d[0] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[1]) as u128 + N_1 as u128;
        self.d[1] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[2]) as u128 + N_2 as u128;
        self.d[2] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[3]) as u128 + N_3 as u128;
        self.d[3] = (t as u64) & nonzero;
    }

    pub fn mul(&mut self, a: &Scalar, b: &Scalar) {
        let mut l = [0u64; 8];
        scalar_mul_512(&mut l, a, b);
        scalar_reduce_512(self, &l);
    }

    /// split_lambda: find r1, r2 such that r1 + r2*lambda == k (mod n)
    pub fn split_lambda(r1: &mut Scalar, r2: &mut Scalar, k: &Scalar) {
        const MINUS_B1: Scalar = Scalar {
            d: [
                (0x6F547FA9u64 << 32) | 0x0ABFE4C3,
                (0xE4437ED6u64 << 32) | 0x010E8828,
                0,
                0,
            ],
        };
        const MINUS_B2: Scalar = Scalar {
            d: [
                (0xD765CDA8u64 << 32) | 0x3DB1562C,
                (0x8A280AC5u64 << 32) | 0x0774346D,
                (0xFFFFFFFFu64 << 32) | 0xFFFFFFFE,
                (0xFFFFFFFFu64 << 32) | 0xFFFFFFFF,
            ],
        };
        const G1: Scalar = Scalar {
            d: [
                (0xE893209Au64 << 32) | 0x45DBB031,
                (0x3DAA8A14u64 << 32) | 0x71E8CA7F,
                (0xE86C90E4u64 << 32) | 0x9284EB15,
                (0x3086D221u64 << 32) | 0xA7D46BCD,
            ],
        };
        const G2: Scalar = Scalar {
            d: [
                (0x1571B4AEu64 << 32) | 0x8AC47F71,
                (0x221208ACu64 << 32) | 0x9DF506C6,
                (0x6F547FA9u64 << 32) | 0x0ABFE4C4,
                (0xE4437ED6u64 << 32) | 0x010E8828,
            ],
        };

        let mut c1 = Scalar::zero();
        let mut c2 = Scalar::zero();
        scalar_mul_shift_var(&mut c1, k, &G1, 384);
        scalar_mul_shift_var(&mut c2, k, &G2, 384);
        let mut t = Scalar::zero();
        t.mul(&c1, &MINUS_B1);
        c1 = t;
        t.mul(&c2, &MINUS_B2);
        c2 = t;
        r2.add(&c1, &c2);
        r1.mul(r2, &LAMBDA);
        let mut neg = Scalar::zero();
        neg.negate(r1);
        r1.add(&neg, k);
    }

    /// Extract `count` bits at `offset` (0..256). Count in [1,32].
    /// All requested bits must lie within a single 64-bit limb (debug-asserted);
    /// [`get_bits_var`](Self::get_bits_var) handles the cross-limb case.
    pub fn get_bits_limb32(&self, offset: u32, count: u32) -> u32 {
        debug_assert!(count > 0 && count <= 32);
        debug_assert!((offset + count - 1) >> 6 == offset >> 6);
        let limb = offset >> 6;
        let shift = offset & 0x3F;
        let mask = if count == 32 {
            u32::MAX
        } else {
            (1u32 << count) - 1
        };
        ((self.d[limb as usize] >> shift) as u32) & mask
    }

    /// Extract `count` bits at `offset`. Count in [1,32], offset+count <= 256.
    pub fn get_bits_var(&self, offset: u32, count: u32) -> u32 {
        debug_assert!(count > 0 && count <= 32);
        debug_assert!(offset + count <= 256);
        if (offset + count - 1) >> 6 == offset >> 6 {
            self.get_bits_limb32(offset, count)
        } else {
            let limb = (offset >> 6) as usize;
            let shift = offset & 0x3F;
            let mask = if count == 32 {
                u32::MAX
            } else {
                (1u32 << count) - 1
            };
            let lo = self.d[limb] >> shift;
            let hi = self.d[limb + 1].wrapping_shl(64u32 - shift);
            ((lo | hi) as u32) & mask
        }
    }

    pub fn split_128(r1: &mut Scalar, r2: &mut Scalar, k: &Scalar) {
        r1.d[0] = k.d[0];
        r1.d[1] = k.d[1];
        r1.d[2] = 0;
        r1.d[3] = 0;
        r2.d[0] = k.d[2];
        r2.d[1] = k.d[3];
        r2.d[2] = 0;
        r2.d[3] = 0;
    }

    /// Modular inverse r = a^(-1) (mod n). If a is zero, r is zero.
    /// On `x86_64` and `aarch64` this uses the same safegcd / `modinv64` path as libsecp256k1.
    /// On other targets it falls back to Fermat; treat those targets as not secret-safe for timing.
    pub fn inv(&mut self, a: &Scalar) {
        if a.is_zero() {
            *self = Scalar::zero();
            return;
        }
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            use crate::modinv64::{modinv64, SECP256K1_SCALAR_MODINV_MODINFO};
            let mut x = scalar_to_signed62(a);
            modinv64(&mut x, &SECP256K1_SCALAR_MODINV_MODINFO);
            scalar_from_signed62(self, &x);
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let a_big = scalar_to_biguint(a);
            let n_big = scalar_to_biguint(&N);
            let exp = &n_big - 2u32;
            let inv_big = a_big.modpow(&exp, &n_big);
            biguint_to_scalar(self, &inv_big);
        }
    }

    /// Backwards compatibility: same as [`Self::inv`].
    pub fn inv_var(&mut self, a: &Scalar) {
        self.inv(a);
    }

    /// True if the scalar is greater than n/2, i.e. in the upper half of [0, n).
    pub fn is_high(&self) -> bool {
        let mut yes = 0u64;
        let mut no = 0u64;
        no |= (self.d[3] < N_H_3) as u64;
        yes |= (self.d[3] > N_H_3) as u64 & !no;
        no |= (self.d[2] < N_H_2) as u64 & !yes;
        no |= (self.d[1] < N_H_1) as u64 & !yes;
        yes |= (self.d[1] > N_H_1) as u64 & !no;
        yes |= (self.d[0] > N_H_0) as u64 & !no;
        yes != 0
    }

    /// Conditionally negate: if flag != 0, negate in place. Returns 1 if negated, -1 if not.
    ///
    /// Constant-time: mask and nonzero are constructed without branches; the
    /// return value is computed via branchless sign-extension of the mask's MSB.
    pub fn cond_negate(&mut self, flag: i32) -> i32 {
        // Build masks without branches.
        let mask = 0u64.wrapping_sub((flag != 0) as u64); // 0 or u64::MAX
        let nonzero = 0u64.wrapping_sub((!self.is_zero()) as u64); // 0 or u64::MAX
        let mut t: u128 = (self.d[0] ^ mask) as u128;
        t += ((N_0 + 1) & mask) as u128;
        self.d[0] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[1] ^ mask) as u128;
        t += (N_1 & mask) as u128;
        self.d[1] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[2] ^ mask) as u128;
        t += (N_2 & mask) as u128;
        self.d[2] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[3] ^ mask) as u128;
        t += (N_3 & mask) as u128;
        self.d[3] = (t as u64) & nonzero;
        // Return 1 if negated, -1 if not — branchless via mask MSB.
        // mask=0   → (0>>63)=0  → 0*2-1 = -1
        // mask=MAX → (MAX>>63)=1 → 1*2-1 =  1
        ((mask >> 63) as i32) * 2 - 1
    }
}
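
// A minimal sanity-check sketch (an assumed test, not part of the upstream libsecp256k1
// port): exercises `negate` and `cond_negate` through observable behaviour only. It checks
// that a + (-a) == 0 (mod n), and that `cond_negate` returns -1 / 1 for flag 0 / 1 while
// acting as an involution when applied twice.
#[cfg(test)]
#[test]
fn test_negate_and_cond_negate_sketch() {
    // a + (-a) should reduce to zero.
    let mut a = Scalar::zero();
    a.set_int(7);
    let mut neg_a = Scalar::zero();
    neg_a.negate(&a);
    let mut sum = Scalar::zero();
    sum.add(&a, &neg_a);
    assert!(sum.is_zero(), "a + (-a) must be 0 mod n");

    // flag = 0 leaves the value untouched and returns -1.
    let mut x = Scalar::one();
    assert_eq!(x.cond_negate(0), -1);
    assert!(x.is_one());

    // flag = 1 negates (returns 1); negating twice restores the original.
    assert_eq!(x.cond_negate(1), 1);
    assert_eq!(x.cond_negate(1), 1);
    assert!(x.is_one());
}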

impl ConstantTimeEq for Scalar {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.d[0].ct_eq(&other.d[0])
            & self.d[1].ct_eq(&other.d[1])
            & self.d[2].ct_eq(&other.d[2])
            & self.d[3].ct_eq(&other.d[3])
    }
}

impl ConditionallySelectable for Scalar {
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        Self {
            d: [
                u64::conditional_select(&a.d[0], &b.d[0], choice),
                u64::conditional_select(&a.d[1], &b.d[1], choice),
                u64::conditional_select(&a.d[2], &b.d[2], choice),
                u64::conditional_select(&a.d[3], &b.d[3], choice),
            ],
        }
    }
}
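
// Illustrative sketch (assumed test, not from upstream): the `subtle` impls above give
// branch-free equality and selection. `conditional_select(a, b, choice)` returns `a` for
// `Choice` 0 and `b` for `Choice` 1, and `ct_eq` compares all four limbs.
#[cfg(test)]
#[test]
fn test_subtle_impls_sketch() {
    let a = Scalar::one();
    let mut b = Scalar::zero();
    b.set_int(5);
    let pick_a = Scalar::conditional_select(&a, &b, Choice::from(0u8));
    let pick_b = Scalar::conditional_select(&a, &b, Choice::from(1u8));
    assert!(bool::from(pick_a.ct_eq(&a)));
    assert!(bool::from(pick_b.ct_eq(&b)));
    assert!(!bool::from(a.ct_eq(&b)));
}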

/// Pack 4×64 scalar limbs into 5×62 signed limbs for modinv64.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn scalar_to_signed62(a: &Scalar) -> crate::modinv64::Signed62 {
    const M62: u64 = u64::MAX >> 2;
    let d = &a.d;
    crate::modinv64::Signed62 {
        v: [
            (d[0] & M62) as i64,
            ((d[0] >> 62 | d[1] << 2) & M62) as i64,
            ((d[1] >> 60 | d[2] << 4) & M62) as i64,
            ((d[2] >> 58 | d[3] << 6) & M62) as i64,
            (d[3] >> 56) as i64,
        ],
    }
}

/// Unpack 5×62 signed limbs back to 4×64 scalar limbs.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn scalar_from_signed62(r: &mut Scalar, a: &crate::modinv64::Signed62) {
    let v = &a.v;
    r.d[0] = (v[0] as u64) | ((v[1] as u64) << 62);
    r.d[1] = ((v[1] as u64) >> 2) | ((v[2] as u64) << 60);
    r.d[2] = ((v[2] as u64) >> 4) | ((v[3] as u64) << 58);
    r.d[3] = ((v[3] as u64) >> 6) | ((v[4] as u64) << 56);
}
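
// Round-trip sketch (assumed test): packing into signed-62 limbs and unpacking again must
// reproduce the original 4x64 limbs for any 256-bit value, since the conversion only
// regroups bits. Gated to the same targets as the conversions themselves.
#[cfg(all(test, any(target_arch = "x86_64", target_arch = "aarch64")))]
#[test]
fn test_signed62_round_trip_sketch() {
    let mut a = Scalar::zero();
    a.d = [
        0x0123456789ABCDEF,
        0xFEDCBA9876543210,
        0x0F0F0F0F0F0F0F0F,
        0x7FFFFFFFFFFFFFFF,
    ];
    let packed = scalar_to_signed62(&a);
    let mut back = Scalar::zero();
    scalar_from_signed62(&mut back, &packed);
    assert_eq!(back.d, a.d);
}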

#[allow(dead_code)]
fn scalar_to_biguint(s: &Scalar) -> BigUint {
    let mut bytes = [0u8; 32];
    s.get_b32(&mut bytes);
    BigUint::from_bytes_be(&bytes)
}

#[allow(dead_code)]
fn biguint_to_scalar(r: &mut Scalar, b: &BigUint) {
    let bytes = b.to_bytes_be();
    let mut buf = [0u8; 32];
    let len = bytes.len().min(32);
    let start = 32 - len;
    buf[start..].copy_from_slice(&bytes[..len]);
    r.set_b32(&buf);
}

fn read_be64(b: &[u8]) -> u64 {
    u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])
}

fn write_be64(b: &mut [u8], v: u64) {
    b[0..8].copy_from_slice(&v.to_be_bytes());
}
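
// Serialization sketch (assumed test): a 32-byte big-endian value below n survives a
// set_b32 / get_b32 round trip unchanged and reports no overflow.
#[cfg(test)]
#[test]
fn test_b32_round_trip_sketch() {
    let mut bytes = [0u8; 32];
    bytes[31] = 0x2A; // 42, far below n
    let mut s = Scalar::zero();
    assert!(!s.set_b32(&bytes), "42 < n, so no reduction should occur");
    assert_eq!(s.d, [42, 0, 0, 0]);
    let mut out = [0u8; 32];
    s.get_b32(&mut out);
    assert_eq!(out, bytes);
}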

fn scalar_mul_512(l: &mut [u64; 8], a: &Scalar, b: &Scalar) {
    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            scalar_asm::scalar_mul_512_asm(l.as_mut_ptr(), a, b);
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_mul_512_rust(l, a, b);
    }
}

#[cfg(not(target_arch = "x86_64"))]
fn scalar_mul_512_rust(l: &mut [u64; 8], a: &Scalar, b: &Scalar) {
    let mut c0: u64 = 0;
    let mut c1: u64 = 0;
    let mut c2: u32 = 0;

    macro_rules! muladd_fast {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let prod_lo = prod as u64;
            let prod_hi = (prod >> 64) as u64;
            let (lo, o) = c0.overflowing_add(prod_lo);
            c0 = lo;
            c1 += prod_hi + o as u64; // C: th = prod_hi + (c0 < tl)
        }};
    }
    macro_rules! muladd {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let hi = (prod >> 64) as u64;
            let (lo, o1) = c0.overflowing_add(prod as u64);
            c0 = lo;
            let th = hi + o1 as u64;
            let (mid, o2) = c1.overflowing_add(th);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    // Mirrors the libsecp256k1 helper set; the 512-bit multiply below never needs it.
    #[allow(unused_macros)]
    macro_rules! sumadd {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            c1 += o as u64;
            c2 += (c1 == 0 && o) as u32;
        }};
    }
    macro_rules! extract {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = c2 as u64;
            c2 = 0;
            n
        }};
    }
    macro_rules! extract_fast {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = 0;
            n
        }};
    }

    muladd_fast!(a.d[0], b.d[0]);
    l[0] = extract_fast!();
    muladd!(a.d[0], b.d[1]);
    muladd!(a.d[1], b.d[0]);
    l[1] = extract!();
    muladd!(a.d[0], b.d[2]);
    muladd!(a.d[1], b.d[1]);
    muladd!(a.d[2], b.d[0]);
    l[2] = extract!();
    muladd!(a.d[0], b.d[3]);
    muladd!(a.d[1], b.d[2]);
    muladd!(a.d[2], b.d[1]);
    muladd!(a.d[3], b.d[0]);
    l[3] = extract!();
    muladd!(a.d[1], b.d[3]);
    muladd!(a.d[2], b.d[2]);
    muladd!(a.d[3], b.d[1]);
    l[4] = extract!();
    muladd!(a.d[2], b.d[3]);
    muladd!(a.d[3], b.d[2]);
    l[5] = extract!();
    muladd_fast!(a.d[3], b.d[3]);
    l[6] = extract_fast!();
    l[7] = c0;
}

#[allow(dead_code)]
fn limbs_512_to_biguint(l: &[u64; 8]) -> BigUint {
    let mut acc = BigUint::from(0u64);
    for (i, &limb) in l.iter().enumerate() {
        acc += BigUint::from(limb) << (64 * i);
    }
    acc
}
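
// Cross-check sketch (assumed test): the 512-bit product from `scalar_mul_512` (asm on
// x86_64, the Rust fallback elsewhere) must match a reference BigUint multiplication.
// The inputs are wide to exercise all sixteen partial products.
#[cfg(test)]
#[test]
fn test_scalar_mul_512_matches_biguint() {
    let mut a = Scalar::zero();
    a.d = [
        0xFFFFFFFFFFFFFFFF,
        0x0123456789ABCDEF,
        0xDEADBEEFCAFEBABE,
        0x0FFFFFFFFFFFFFFF,
    ];
    let mut b = Scalar::zero();
    b.d = [
        0x8000000000000001,
        0xFEDCBA9876543210,
        0x0000000000000007,
        0x0ABCDEF012345678,
    ];
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, &a, &b);
    let expect = scalar_to_biguint(&a) * scalar_to_biguint(&b);
    assert_eq!(limbs_512_to_biguint(&l), expect);
}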

/// Limb-based 512→256 reduction mod n. Replaces BigUint for hot path.
/// Port of libsecp256k1 scalar_reduce_512 C fallback (muladd/extract).
#[cfg(not(target_arch = "x86_64"))]
fn scalar_reduce_512_limbs(r: &mut Scalar, l: &[u64; 8]) {
    let n0 = l[4];
    let n1 = l[5];
    let n2 = l[6];
    let n3 = l[7];

    let mut c0: u64 = l[0];
    let mut c1: u64 = 0;
    let mut c2: u32 = 0;

    macro_rules! muladd_fast {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let (lo, o) = c0.overflowing_add(prod as u64);
            c0 = lo;
            c1 += (prod >> 64) as u64 + o as u64;
        }};
    }
    macro_rules! muladd {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let (lo, o1) = c0.overflowing_add(prod as u64);
            c0 = lo;
            let th = (prod >> 64) as u64 + o1 as u64;
            let (mid, o2) = c1.overflowing_add(th);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! sumadd_fast {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            c1 += o as u64;
        }};
    }
    macro_rules! sumadd {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            let (mid, o2) = c1.overflowing_add(o as u64);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! extract {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = c2 as u64;
            c2 = 0;
            n
        }};
    }
    macro_rules! extract_fast {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = 0;
            n
        }};
    }

    // Reduce 512 bits into 385: m[0..6] = l[0..3] + n[0..3] * N_C
    muladd_fast!(n0, N_C_0);
    let m0 = extract_fast!();
    sumadd_fast!(l[1]);
    muladd!(n1, N_C_0);
    muladd!(n0, N_C_1);
    let m1 = extract!();
    sumadd!(l[2]);
    muladd!(n2, N_C_0);
    muladd!(n1, N_C_1);
    sumadd!(n0);
    let m2 = extract!();
    sumadd!(l[3]);
    muladd!(n3, N_C_0);
    muladd!(n2, N_C_1);
    sumadd!(n1);
    let m3 = extract!();
    muladd!(n3, N_C_1);
    sumadd!(n2);
    let m4 = extract!();
    sumadd_fast!(n3);
    let m5 = extract_fast!();
    let m6 = c0 as u32;

    // Reduce 385 into 258: p[0..4] = m[0..3] + m[4..6] * N_C
    c0 = m0;
    c1 = 0;
    c2 = 0;
    muladd_fast!(m4, N_C_0);
    let p0 = extract_fast!();
    sumadd_fast!(m1);
    muladd!(m5, N_C_0);
    muladd!(m4, N_C_1);
    let p1 = extract!();
    sumadd!(m2);
    muladd!(m6 as u64, N_C_0);
    muladd!(m5, N_C_1);
    sumadd!(m4);
    let p2 = extract!();
    sumadd_fast!(m3);
    muladd_fast!(m6 as u64, N_C_1);
    sumadd_fast!(m5);
    let p3 = extract_fast!();
    let p4 = (c0 + m6 as u64) as u32;

    // Reduce 258 into 256: r = p[0..3] + p4 * N_C
    let mut t: u128 = p0 as u128;
    t += (N_C_0 as u128) * (p4 as u128);
    r.d[0] = t as u64;
    t >>= 64;
    t += p1 as u128;
    t += (N_C_1 as u128) * (p4 as u128);
    r.d[1] = t as u64;
    t >>= 64;
    t += p2 as u128;
    t += p4 as u128;
    r.d[2] = t as u64;
    t >>= 64;
    t += p3 as u128;
    r.d[3] = t as u64;
    let c = (t >> 64) as u64;

    // Final reduction
    scalar_reduce(r, c + scalar_check_overflow(r));
}

/// Add overflow*N_C to r. overflow is 0 or 1.
fn scalar_reduce(r: &mut Scalar, overflow: u64) {
    let of = overflow as u128;
    let mut t: u128 = r.d[0] as u128;
    t += of * (N_C_0 as u128);
    r.d[0] = t as u64;
    t >>= 64;
    t += r.d[1] as u128;
    t += of * (N_C_1 as u128);
    r.d[1] = t as u64;
    t >>= 64;
    t += r.d[2] as u128;
    t += of * (N_C_2 as u128);
    r.d[2] = t as u64;
    t >>= 64;
    r.d[3] = (t as u64).wrapping_add(r.d[3]);
}

/// Returns 1 if r >= n, else 0.
fn scalar_check_overflow(r: &Scalar) -> u64 {
    let mut yes = 0u64;
    let mut no = 0u64;
    no |= (r.d[3] < N_3) as u64;
    no |= (r.d[2] < N_2) as u64;
    yes |= (r.d[2] > N_2) as u64 & !no;
    no |= (r.d[1] < N_1) as u64;
    yes |= (r.d[1] > N_1) as u64 & !no;
    yes |= (r.d[0] >= N_0) as u64 & !no;
    yes
}

fn scalar_reduce_512(r: &mut Scalar, l: &[u64; 8]) {
    #[cfg(target_arch = "x86_64")]
    {
        let c = unsafe { scalar_asm::scalar_reduce_512_asm(r, l.as_ptr()) };
        scalar_reduce(r, c + scalar_check_overflow(r));
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_reduce_512_limbs(r, l);
    }
}
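
// Boundary sketch (assumed test, a companion to the n+1 case below): reducing exactly n
// must yield zero, exercising the final conditional subtraction.
#[cfg(test)]
#[test]
fn test_scalar_reduce_exact_n() {
    let l = [N_0, N_1, N_2, N_3, 0, 0, 0, 0];
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_zero(), "n mod n = 0, got r.d = {:?}", r.d);
}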

#[cfg(test)]
#[test]
fn test_scalar_reduce_n_plus_1() {
    let l = [N_0 + 1, N_1, N_2, N_3, 0, 0, 0, 0];
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_one(), "(n+1) mod n = 1, got r.d = {:?}", r.d);
}

#[cfg(test)]
#[test]
fn test_scalar_mul_inv2_times_2() {
    let inv2_hex = "7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1";
    let inv2_bytes = hex::decode(inv2_hex).unwrap();
    let mut buf = [0u8; 32];
    buf.copy_from_slice(&inv2_bytes);
    let mut inv2 = Scalar::zero();
    inv2.set_b32(&buf);
    let mut two = Scalar::zero();
    two.set_int(2);
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, &inv2, &two);
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_one(), "inv2*2 mod n = 1");
}
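
// Halving sketch (assumed test): `div2` on an odd value adds n before shifting, so
// 1/2 mod n is (n+1)/2 and doubling it gives back 1; an even value is simply shifted.
#[cfg(test)]
#[test]
fn test_div2_is_halving_mod_n() {
    let mut x = Scalar::one();
    x.div2(); // x = (n+1)/2
    let mut doubled = Scalar::zero();
    doubled.add(&x, &x);
    assert!(doubled.is_one(), "2 * (1/2 mod n) should be 1");

    let mut y = Scalar::zero();
    y.set_int(8);
    y.div2();
    assert_eq!(y.d, [4, 0, 0, 0]);
}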

fn scalar_mul_shift_var(r: &mut Scalar, a: &Scalar, b: &Scalar, shift: u32) {
    assert!(shift >= 256);
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, a, b);
    let shiftlimbs = (shift >> 6) as usize;
    let shiftlow = shift & 0x3F;
    let shifthigh = 64 - shiftlow;
    r.d[0] = if shift < 512 {
        (l[shiftlimbs] >> shiftlow)
            | (if shift < 448 && shiftlow != 0 {
                l[1 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[1] = if shift < 448 {
        (l[1 + shiftlimbs] >> shiftlow)
            | (if shift < 384 && shiftlow != 0 {
                l[2 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[2] = if shift < 384 {
        (l[2 + shiftlimbs] >> shiftlow)
            | (if shift < 320 && shiftlow != 0 {
                l[3 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[3] = if shift < 320 {
        l[3 + shiftlimbs] >> shiftlow
    } else {
        0
    };
    let bit = (l[(shift - 1) as usize >> 6] >> ((shift - 1) & 0x3F)) & 1;
    scalar_cadd_bit(r, 0, bit != 0);
}
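
// Rounding sketch (assumed test): `scalar_mul_shift_var(r, a, b, shift)` computes the
// truncated quotient (a*b) >> shift plus the bit just below the cut, i.e. it rounds to
// nearest. Cross-checked against BigUint using round(p / 2^s) == (p + 2^(s-1)) >> s.
#[cfg(test)]
#[test]
fn test_scalar_mul_shift_var_rounds_to_nearest() {
    let mut a = Scalar::zero();
    a.d = [
        0x0123456789ABCDEF,
        0x0011223344556677,
        0x8899AABBCCDDEEFF,
        0x0F0E0D0C0B0A0908,
    ];
    let mut b = Scalar::zero();
    b.d = [
        0x1111111111111111,
        0x2222222222222222,
        0x3333333333333333,
        0x0444444444444444,
    ];
    let mut r = Scalar::zero();
    scalar_mul_shift_var(&mut r, &a, &b, 384);

    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, &a, &b);
    let product = limbs_512_to_biguint(&l);
    let expect = (product + (BigUint::from(1u8) << 383usize)) >> 384usize;
    assert_eq!(scalar_to_biguint(&r), expect);
}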

fn scalar_cadd_bit(r: &mut Scalar, bit: u32, flag: bool) {
    let bit = if flag { bit } else { bit + 256 };
    if bit >= 256 {
        return;
    }
    let mut t: u128 = r.d[0] as u128
        + if (bit >> 6) == 0 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[0] = t as u64;
    t >>= 64;
    t += r.d[1] as u128
        + if (bit >> 6) == 1 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[1] = t as u64;
    t >>= 64;
    t += r.d[2] as u128
        + if (bit >> 6) == 2 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[2] = t as u64;
    t >>= 64;
    t += r.d[3] as u128
        + if (bit >> 6) == 3 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[3] = t as u64;
}
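
// Sketch (assumed test) of the conditional bit add used by `scalar_mul_shift_var`:
// with flag set, 2^bit is added to r; with flag clear, the bit index is pushed past
// 255 so the function returns without touching r.
#[cfg(test)]
#[test]
fn test_scalar_cadd_bit_sketch() {
    let mut r = Scalar::zero();
    scalar_cadd_bit(&mut r, 70, true); // adds 2^70, i.e. bit 6 of d[1]
    assert_eq!(r.d, [0, 1 << 6, 0, 0]);
    scalar_cadd_bit(&mut r, 70, false); // flag = false: no change
    assert_eq!(r.d, [0, 1 << 6, 0, 0]);
}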

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_lambda_identity() {
        // r1 + lambda*r2 == k (mod n)
        let mut k = Scalar::zero();
        k.set_int(42);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(bool::from(check.ct_eq(&k)), "r1 + lambda*r2 should equal k");
    }

    #[test]
    fn test_split_lambda_neg_three() {
        let mut three = Scalar::zero();
        three.set_int(3);
        let mut k = Scalar::zero();
        k.negate(&three); // k = -3 mod n

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(
            bool::from(check.ct_eq(&k)),
            "r1 + lambda*r2 should equal k for k=-3"
        );
    }

    #[test]
    fn test_split_lambda_ecdsa_scalar() {
        let mut k = Scalar::zero();
        k.d = [
            11125243483441707226,
            2149109665766520832,
            14302025600096445326,
            4162584031737161978,
        ];

        let n_big = scalar_to_biguint(&N);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let r1_big = scalar_to_biguint(&r1);
        let r2_big = scalar_to_biguint(&r2);
        let n_half = &n_big / BigUint::from(2u64);
        let r1_abs = if r1_big > n_half {
            &n_big - &r1_big
        } else {
            r1_big.clone()
        };
        let r2_abs = if r2_big > n_half {
            &n_big - &r2_big
        } else {
            r2_big.clone()
        };
        assert!(
            r1_abs.bits() <= 128,
            "|r1| exceeds 128 bits: {}",
            r1_abs.bits()
        );
        assert!(
            r2_abs.bits() <= 128,
            "|r2| exceeds 128 bits: {}",
            r2_abs.bits()
        );

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(bool::from(check.ct_eq(&k)), "r1 + lambda*r2 should equal k");
    }

    #[test]
    fn test_split_128_identity() {
        // r1 + 2^128*r2 == k
        let mut k = Scalar::zero();
        k.set_int(0x1234_5678);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_128(&mut r1, &mut r2, &k);

        let mut two_128 = Scalar::zero();
        two_128.d[2] = 1;
        let mut r2_shifted = Scalar::zero();
        r2_shifted.mul(&r2, &two_128);
        let mut check = Scalar::zero();
        check.add(&r1, &r2_shifted);
        assert!(bool::from(check.ct_eq(&k)), "r1 + 2^128*r2 should equal k");
    }
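
    // Bit-extraction sketch (assumed test): `get_bits_var` stitches together bits that
    // straddle a limb boundary, while the single-limb case goes through `get_bits_limb32`.
    #[test]
    fn test_get_bits_across_limb_boundary() {
        let mut k = Scalar::zero();
        k.d = [0x8000_0000_0000_0000, 0x1, 0, 0]; // bits 63 and 64 set
        // Bits 60..67: bit 63 contributes 2^3, bit 64 contributes 2^4.
        assert_eq!(k.get_bits_var(60, 8), 0x18);
        // Single-limb fast path.
        assert_eq!(k.get_bits_var(63, 1), 1);
        assert_eq!(k.get_bits_limb32(0, 32), 0);
    }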
}