#[cfg(target_arch = "x86_64")]
mod scalar_asm {
    use super::Scalar;

    extern "C" {
        fn blvm_secp256k1_scalar_mul_512(l8: *mut u64, a: *const Scalar, b: *const Scalar);

        fn blvm_secp256k1_scalar_reduce_512(r: *mut Scalar, l: *const u64) -> u64;
    }

    #[inline(always)]
    pub(super) unsafe fn scalar_mul_512_asm(l: *mut u64, a: *const Scalar, b: *const Scalar) {
        blvm_secp256k1_scalar_mul_512(l, a, b);
    }

    #[inline(always)]
    pub(super) unsafe fn scalar_reduce_512_asm(r: *mut Scalar, l: *const u64) -> u64 {
        blvm_secp256k1_scalar_reduce_512(r, l)
    }
}

use num_bigint::BigUint;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

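/// A scalar modulo the secp256k1 group order `n`, stored as four 64-bit limbs
/// in little-endian order (`d[0]` is least significant). The `#[repr(C)]`
/// layout lets pointers to it be passed to the assembly backend.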
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct Scalar {
    pub d: [u64; 4],
}

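// Limbs of the secp256k1 group order
// n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141.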
const N_0: u64 = 0xBFD25E8CD0364141;
const N_1: u64 = 0xBAAEDCE6AF48A03B;
const N_2: u64 = 0xFFFFFFFFFFFFFFFE;
const N_3: u64 = 0xFFFFFFFFFFFFFFFF;

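// Limbs of 2^256 - n, the complement of the group order used during
// reduction (the implicit fourth limb is zero).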
const N_C_0: u64 = 0x402DA1732FC9BEBF;
const N_C_1: u64 = 0x4551231950B75FC4;
const N_C_2: u64 = 1;

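// Limbs of n / 2 (rounded down), the threshold used by `is_high`.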
const N_H_0: u64 = 0xDFE92F46681B20A0;
const N_H_1: u64 = 0x5D576E7357A4501D;
const N_H_2: u64 = 0xFFFFFFFFFFFFFFFF;
const N_H_3: u64 = 0x7FFFFFFFFFFFFFFF;

#[allow(dead_code)]
const N: Scalar = Scalar {
    d: [N_0, N_1, N_2, N_3],
};

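// The cube root of unity modulo n used by the GLV endomorphism:
// lambda^3 = 1 (mod n).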
const LAMBDA: Scalar = Scalar {
    d: [
        0xDF02967C1B23BD72,
        0x122E22EA20816678,
        0xA5261C028812645A,
        0x5363AD4CC05C30E0,
    ],
};

impl Scalar {
    pub fn zero() -> Self {
        Self { d: [0, 0, 0, 0] }
    }

    pub fn one() -> Self {
        Self { d: [1, 0, 0, 0] }
    }

    pub fn set_int(&mut self, v: u32) {
        self.d[0] = v as u64;
        self.d[1] = 0;
        self.d[2] = 0;
        self.d[3] = 0;
    }

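    /// Sets the scalar from 32 big-endian bytes, reducing modulo n.
    /// Returns `true` if the input was >= n (i.e. a reduction occurred).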
    pub fn set_b32(&mut self, bin: &[u8; 32]) -> bool {
        self.d[0] = read_be64(&bin[24..32]);
        self.d[1] = read_be64(&bin[16..24]);
        self.d[2] = read_be64(&bin[8..16]);
        self.d[3] = read_be64(&bin[0..8]);
        let overflow = self.check_overflow();
        self.reduce(overflow as u64);
        overflow
    }

    pub fn get_b32(&self, bin: &mut [u8; 32]) {
        write_be64(&mut bin[0..8], self.d[3]);
        write_be64(&mut bin[8..16], self.d[2]);
        write_be64(&mut bin[16..24], self.d[1]);
        write_be64(&mut bin[24..32], self.d[0]);
    }

    fn check_overflow(&self) -> bool {
        let mut yes = 0u64;
        let mut no = 0u64;
        no |= (self.d[3] < N_3) as u64;
        no |= (self.d[2] < N_2) as u64;
        yes |= (self.d[2] > N_2) as u64 & !no;
        no |= (self.d[1] < N_1) as u64;
        yes |= (self.d[1] > N_1) as u64 & !no;
        yes |= (self.d[0] >= N_0) as u64 & !no;
        yes != 0
    }

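    /// Conditionally subtracts n: when `overflow` is 1, adds 2^256 - n
    /// (the N_C_* limbs), which equals subtracting n modulo 2^256.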
    fn reduce(&mut self, overflow: u64) {
        let mut t: u128 = self.d[0] as u128 + (overflow as u128 * N_C_0 as u128);
        self.d[0] = t as u64;
        t >>= 64;
        t += self.d[1] as u128 + (overflow as u128 * N_C_1 as u128);
        self.d[1] = t as u64;
        t >>= 64;
        t += self.d[2] as u128 + (overflow as u128 * N_C_2 as u128);
        self.d[2] = t as u64;
        t >>= 64;
        t += self.d[3] as u128;
        self.d[3] = t as u64;
    }

    pub fn is_zero(&self) -> bool {
        (self.d[0] | self.d[1] | self.d[2] | self.d[3]) == 0
    }

    pub fn is_one(&self) -> bool {
        ((self.d[0] ^ 1) | self.d[1] | self.d[2] | self.d[3]) == 0
    }

    pub fn is_odd(&self) -> bool {
        self.d[0] & 1 != 0
    }

    pub(crate) fn is_even(&self) -> bool {
        self.d[0] & 1 == 0
    }

    #[allow(dead_code)]
    fn sub(&mut self, a: &Scalar, b: &Scalar) {
        let mut neg_b = Scalar::zero();
        neg_b.negate(b);
        self.add(a, &neg_b);
    }

    #[allow(dead_code)]
    fn half(&mut self) {
        self.d[0] = (self.d[0] >> 1) | (self.d[1] << 63);
        self.d[1] = (self.d[1] >> 1) | (self.d[2] << 63);
        self.d[2] = (self.d[2] >> 1) | (self.d[3] << 63);
        self.d[3] >>= 1;
    }

    #[allow(dead_code)]
    fn half_add_n(&mut self) {
        let mut t: u128 = self.d[0] as u128 + N_0 as u128;
        let c0 = t as u64;
        let mut c1 = (t >> 64) as u64;
        t = self.d[1] as u128 + N_1 as u128 + c1 as u128;
        c1 = t as u64;
        let mut c2 = (t >> 64) as u64;
        t = self.d[2] as u128 + N_2 as u128 + c2 as u128;
        c2 = t as u64;
        let mut c3 = (t >> 64) as u64;
        t = self.d[3] as u128 + N_3 as u128 + c3 as u128;
        c3 = t as u64;
        let c4 = (t >> 64) as u64;
        self.d[0] = (c0 >> 1) | (c1 << 63);
        self.d[1] = (c1 >> 1) | (c2 << 63);
        self.d[2] = (c2 >> 1) | (c3 << 63);
        self.d[3] = (c3 >> 1) | (c4 << 63);
        self.reduce(self.check_overflow() as u64);
    }

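    /// Halves the scalar modulo n in constant time: if the value is odd,
    /// n (which is odd) is first added so the sum is even, then the 257-bit
    /// result is shifted right by one.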
    pub fn div2(&mut self) {
        let add_mask = 0u64.wrapping_sub(self.d[0] & 1);
        let mut t: u128 = self.d[0] as u128 + (N_0 & add_mask) as u128;
        let c0 = t as u64;
        t >>= 64;
        t += self.d[1] as u128 + (N_1 & add_mask) as u128;
        let c1 = t as u64;
        t >>= 64;
        t += self.d[2] as u128 + (N_2 & add_mask) as u128;
        let c2 = t as u64;
        t >>= 64;
        t += self.d[3] as u128 + (N_3 & add_mask) as u128;
        let c3 = t as u64;
        let c4 = (t >> 64) as u64;
        self.d[0] = (c0 >> 1) | (c1 << 63);
        self.d[1] = (c1 >> 1) | (c2 << 63);
        self.d[2] = (c2 >> 1) | (c3 << 63);
        self.d[3] = (c3 >> 1) | (c4 << 63);
        self.reduce(self.check_overflow() as u64);
    }

    pub fn half_modn(&mut self, a: &Scalar) {
        *self = *a;
        self.div2();
    }

    #[allow(dead_code)]
    fn add_no_reduce(a: &Scalar, b: &Scalar) -> [u64; 5] {
        let mut t: u128 = a.d[0] as u128 + b.d[0] as u128;
        let c0 = t as u64;
        let mut c1 = (t >> 64) as u64;
        t = a.d[1] as u128 + b.d[1] as u128 + c1 as u128;
        c1 = t as u64;
        let mut c2 = (t >> 64) as u64;
        t = a.d[2] as u128 + b.d[2] as u128 + c2 as u128;
        c2 = t as u64;
        let mut c3 = (t >> 64) as u64;
        t = a.d[3] as u128 + b.d[3] as u128 + c3 as u128;
        c3 = t as u64;
        let c4 = (t >> 64) as u64;
        [c0, c1, c2, c3, c4]
    }

    #[allow(dead_code)]
    fn set_from_5limb_half(&mut self, c: &[u64; 5]) {
        self.d[0] = (c[0] >> 1) | (c[1] << 63);
        self.d[1] = (c[1] >> 1) | (c[2] << 63);
        self.d[2] = (c[2] >> 1) | (c[3] << 63);
        self.d[3] = (c[3] >> 1) | (c[4] << 63);
        self.reduce(self.check_overflow() as u64);
    }

    #[allow(dead_code)]
    fn sub_half(&mut self, a: &Scalar, b: &Scalar) {
        self.sub(a, b);
        self.div2();
    }

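    /// Computes `self = (a + b) mod n`. Returns `true` if the unreduced sum
    /// overflowed 256 bits or was >= n.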
    pub fn add(&mut self, a: &Scalar, b: &Scalar) -> bool {
        let mut t: u128 = a.d[0] as u128 + b.d[0] as u128;
        self.d[0] = t as u64;
        t >>= 64;
        t += a.d[1] as u128 + b.d[1] as u128;
        self.d[1] = t as u64;
        t >>= 64;
        t += a.d[2] as u128 + b.d[2] as u128;
        self.d[2] = t as u64;
        t >>= 64;
        t += a.d[3] as u128 + b.d[3] as u128;
        self.d[3] = t as u64;
        t >>= 64;
        let overflow = t as u64 + self.check_overflow() as u64;
        debug_assert!(overflow <= 1);
        self.reduce(overflow);
        overflow != 0
    }

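    /// Computes `self = (n - a) mod n` in constant time. The result is built
    /// as `!a + (n + 1)` (which equals n - a modulo 2^256), then masked to
    /// zero when `a` is zero.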
    pub fn negate(&mut self, a: &Scalar) {
        let nz = a.d[0] | a.d[1] | a.d[2] | a.d[3];
        let nonzero = 0u64.wrapping_sub((nz != 0) as u64);
        let mut t: u128 = (!a.d[0]) as u128 + (N_0 + 1) as u128;
        self.d[0] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[1]) as u128 + N_1 as u128;
        self.d[1] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[2]) as u128 + N_2 as u128;
        self.d[2] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[3]) as u128 + N_3 as u128;
        self.d[3] = (t as u64) & nonzero;
    }

    pub fn mul(&mut self, a: &Scalar, b: &Scalar) {
        let mut l = [0u64; 8];
        scalar_mul_512(&mut l, a, b);
        scalar_reduce_512(self, &l);
    }

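    /// GLV decomposition: splits `k` into `r1` and `r2` such that
    /// `r1 + r2 * lambda = k (mod n)`, with both parts roughly 128 bits.
    /// The constants are the lattice basis vectors `-b1`, `-b2` and the
    /// rounding factors `g1`, `g2`, the same values used by libsecp256k1.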
    pub fn split_lambda(r1: &mut Scalar, r2: &mut Scalar, k: &Scalar) {
        const MINUS_B1: Scalar = Scalar {
            d: [0x6F547FA90ABFE4C3, 0xE4437ED6010E8828, 0, 0],
        };
        const MINUS_B2: Scalar = Scalar {
            d: [
                0xD765CDA83DB1562C,
                0x8A280AC50774346D,
                0xFFFFFFFFFFFFFFFE,
                0xFFFFFFFFFFFFFFFF,
            ],
        };
        const G1: Scalar = Scalar {
            d: [
                0xE893209A45DBB031,
                0x3DAA8A1471E8CA7F,
                0xE86C90E49284EB15,
                0x3086D221A7D46BCD,
            ],
        };
        const G2: Scalar = Scalar {
            d: [
                0x1571B4AE8AC47F71,
                0x221208AC9DF506C6,
                0x6F547FA90ABFE4C4,
                0xE4437ED6010E8828,
            ],
        };

        let mut c1 = Scalar::zero();
        let mut c2 = Scalar::zero();
        scalar_mul_shift_var(&mut c1, k, &G1, 384);
        scalar_mul_shift_var(&mut c2, k, &G2, 384);
        let mut t = Scalar::zero();
        t.mul(&c1, &MINUS_B1);
        c1 = t;
        t.mul(&c2, &MINUS_B2);
        c2 = t;
        r2.add(&c1, &c2);
        r1.mul(r2, &LAMBDA);
        let mut neg = Scalar::zero();
        neg.negate(r1);
        r1.add(&neg, k);
    }

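    /// Extracts `count` bits (1..=32) starting at bit `offset`, assuming the
    /// requested range does not cross a 64-bit limb boundary.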
    pub fn get_bits_limb32(&self, offset: u32, count: u32) -> u32 {
        debug_assert!(count > 0 && count <= 32);
        debug_assert!((offset + count - 1) >> 6 == offset >> 6);
        let limb = offset >> 6;
        let shift = offset & 0x3F;
        let mask = if count == 32 {
            u32::MAX
        } else {
            (1u32 << count) - 1
        };
        ((self.d[limb as usize] >> shift) as u32) & mask
    }

    pub fn get_bits_var(&self, offset: u32, count: u32) -> u32 {
        debug_assert!(count > 0 && count <= 32);
        debug_assert!(offset + count <= 256);
        if (offset + count - 1) >> 6 == offset >> 6 {
            self.get_bits_limb32(offset, count)
        } else {
            let limb = (offset >> 6) as usize;
            let shift = offset & 0x3F;
            let mask = if count == 32 {
                u32::MAX
            } else {
                (1u32 << count) - 1
            };
            let lo = self.d[limb] >> shift;
            let hi = self.d[limb + 1].wrapping_shl(64u32 - shift);
            ((lo | hi) as u32) & mask
        }
    }

    pub fn split_128(r1: &mut Scalar, r2: &mut Scalar, k: &Scalar) {
        r1.d[0] = k.d[0];
        r1.d[1] = k.d[1];
        r1.d[2] = 0;
        r1.d[3] = 0;
        r2.d[0] = k.d[2];
        r2.d[1] = k.d[3];
        r2.d[2] = 0;
        r2.d[3] = 0;
    }

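    /// Computes the modular inverse `self = a^-1 mod n`, with `inv(0) = 0`.
    /// On x86_64/aarch64 this uses the signed-62 modular-inverse routine in
    /// `modinv64`; elsewhere it falls back to Fermat's little theorem
    /// (`a^(n-2) mod n`) via `BigUint`, which is not constant time.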
    pub fn inv(&mut self, a: &Scalar) {
        if a.is_zero() {
            *self = Scalar::zero();
            return;
        }
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            use crate::modinv64::{modinv64, SECP256K1_SCALAR_MODINV_MODINFO};
            let mut x = scalar_to_signed62(a);
            modinv64(&mut x, &SECP256K1_SCALAR_MODINV_MODINFO);
            scalar_from_signed62(self, &x);
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let a_big = scalar_to_biguint(a);
            let n_big = scalar_to_biguint(&N);
            let exp = &n_big - 2u32;
            let inv_big = a_big.modpow(&exp, &n_big);
            biguint_to_scalar(self, &inv_big);
        }
    }

    pub fn inv_var(&mut self, a: &Scalar) {
        self.inv(a);
    }

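    /// Returns `true` if the scalar is greater than n / 2, i.e. if it is the
    /// "high" representative of the pair {s, n - s}.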
    pub fn is_high(&self) -> bool {
        let mut yes = 0u64;
        let mut no = 0u64;
        no |= (self.d[3] < N_H_3) as u64;
        yes |= (self.d[3] > N_H_3) as u64 & !no;
        no |= (self.d[2] < N_H_2) as u64 & !yes;
        no |= (self.d[1] < N_H_1) as u64 & !yes;
        yes |= (self.d[1] > N_H_1) as u64 & !no;
        yes |= (self.d[0] > N_H_0) as u64 & !no;
        yes != 0
    }

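    /// Conditionally negates the scalar in constant time: when `flag` is
    /// nonzero, `self` becomes `(n - self) mod n`. Returns 1 if the scalar
    /// was negated and -1 otherwise.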
    pub fn cond_negate(&mut self, flag: i32) -> i32 {
        let mask = 0u64.wrapping_sub((flag != 0) as u64);
        let nonzero = 0u64.wrapping_sub((!self.is_zero()) as u64);
        let mut t: u128 = (self.d[0] ^ mask) as u128;
        t += ((N_0 + 1) & mask) as u128;
        self.d[0] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[1] ^ mask) as u128;
        t += (N_1 & mask) as u128;
        self.d[1] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[2] ^ mask) as u128;
        t += (N_2 & mask) as u128;
        self.d[2] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[3] ^ mask) as u128;
        t += (N_3 & mask) as u128;
        self.d[3] = (t as u64) & nonzero;
        ((mask >> 63) as i32) * 2 - 1
    }
}

impl ConstantTimeEq for Scalar {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.d[0].ct_eq(&other.d[0])
            & self.d[1].ct_eq(&other.d[1])
            & self.d[2].ct_eq(&other.d[2])
            & self.d[3].ct_eq(&other.d[3])
    }
}

impl ConditionallySelectable for Scalar {
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        Self {
            d: [
                u64::conditional_select(&a.d[0], &b.d[0], choice),
                u64::conditional_select(&a.d[1], &b.d[1], choice),
                u64::conditional_select(&a.d[2], &b.d[2], choice),
                u64::conditional_select(&a.d[3], &b.d[3], choice),
            ],
        }
    }
}

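// Converts a scalar to the signed-62-bit limb representation used by the
// modular-inverse code (five little-endian limbs of at most 62 bits each).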
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn scalar_to_signed62(a: &Scalar) -> crate::modinv64::Signed62 {
    const M62: u64 = u64::MAX >> 2;
    let d = &a.d;
    crate::modinv64::Signed62 {
        v: [
            (d[0] & M62) as i64,
            ((d[0] >> 62 | d[1] << 2) & M62) as i64,
            ((d[1] >> 60 | d[2] << 4) & M62) as i64,
            ((d[2] >> 58 | d[3] << 6) & M62) as i64,
            (d[3] >> 56) as i64,
        ],
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn scalar_from_signed62(r: &mut Scalar, a: &crate::modinv64::Signed62) {
    let v = &a.v;
    r.d[0] = (v[0] as u64) | ((v[1] as u64) << 62);
    r.d[1] = ((v[1] as u64) >> 2) | ((v[2] as u64) << 60);
    r.d[2] = ((v[2] as u64) >> 4) | ((v[3] as u64) << 58);
    r.d[3] = ((v[3] as u64) >> 6) | ((v[4] as u64) << 56);
}

#[allow(dead_code)]
fn scalar_to_biguint(s: &Scalar) -> BigUint {
    let mut bytes = [0u8; 32];
    s.get_b32(&mut bytes);
    BigUint::from_bytes_be(&bytes)
}

#[allow(dead_code)]
fn biguint_to_scalar(r: &mut Scalar, b: &BigUint) {
    let bytes = b.to_bytes_be();
    let mut buf = [0u8; 32];
    let len = bytes.len().min(32);
    let start = 32 - len;
    buf[start..].copy_from_slice(&bytes[..len]);
    r.set_b32(&buf);
}

fn read_be64(b: &[u8]) -> u64 {
    u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])
}

fn write_be64(b: &mut [u8], v: u64) {
    b[0..8].copy_from_slice(&v.to_be_bytes());
}

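// Computes the full 512-bit product l = a * b (eight 64-bit limbs, little
// endian), dispatching to the assembly backend on x86_64 and to the portable
// Rust implementation elsewhere.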
fn scalar_mul_512(l: &mut [u64; 8], a: &Scalar, b: &Scalar) {
    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            scalar_asm::scalar_mul_512_asm(l.as_mut_ptr(), a, b);
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_mul_512_rust(l, a, b);
    }
}

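// Portable schoolbook 256x256 -> 512-bit multiplication. (c0, c1, c2) form a
// 160-bit accumulator: muladd adds a 128-bit partial product into it, and
// extract pops the low 64 bits as the next output limb.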
#[cfg(not(target_arch = "x86_64"))]
fn scalar_mul_512_rust(l: &mut [u64; 8], a: &Scalar, b: &Scalar) {
    let mut c0: u64 = 0;
    let mut c1: u64 = 0;
    let mut c2: u32 = 0;

    macro_rules! muladd_fast {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let prod_lo = prod as u64;
            let prod_hi = (prod >> 64) as u64;
            let (lo, o) = c0.overflowing_add(prod_lo);
            c0 = lo;
            c1 += prod_hi + o as u64;
        }};
    }
    macro_rules! muladd {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let hi = (prod >> 64) as u64;
            let (lo, o1) = c0.overflowing_add(prod as u64);
            c0 = lo;
            let th = hi + o1 as u64;
            let (mid, o2) = c1.overflowing_add(th);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! extract {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = c2 as u64;
            c2 = 0;
            n
        }};
    }
    macro_rules! extract_fast {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = 0;
            n
        }};
    }

    muladd_fast!(a.d[0], b.d[0]);
    l[0] = extract_fast!();
    muladd!(a.d[0], b.d[1]);
    muladd!(a.d[1], b.d[0]);
    l[1] = extract!();
    muladd!(a.d[0], b.d[2]);
    muladd!(a.d[1], b.d[1]);
    muladd!(a.d[2], b.d[0]);
    l[2] = extract!();
    muladd!(a.d[0], b.d[3]);
    muladd!(a.d[1], b.d[2]);
    muladd!(a.d[2], b.d[1]);
    muladd!(a.d[3], b.d[0]);
    l[3] = extract!();
    muladd!(a.d[1], b.d[3]);
    muladd!(a.d[2], b.d[2]);
    muladd!(a.d[3], b.d[1]);
    l[4] = extract!();
    muladd!(a.d[2], b.d[3]);
    muladd!(a.d[3], b.d[2]);
    l[5] = extract!();
    muladd_fast!(a.d[3], b.d[3]);
    l[6] = extract_fast!();
    l[7] = c0;
}

#[allow(dead_code)]
fn limbs_512_to_biguint(l: &[u64; 8]) -> BigUint {
    let mut acc = BigUint::from(0u64);
    for (i, &limb) in l.iter().enumerate() {
        acc += BigUint::from(limb) << (64 * i);
    }
    acc
}

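// Reduces a 512-bit product modulo n using the identity
// 2^256 = 2^256 - n (mod n): the high 256 bits are folded back in twice
// (512 -> 385 bits, then 385 -> 258 bits), followed by a final conditional
// subtraction of n.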
#[cfg(not(target_arch = "x86_64"))]
fn scalar_reduce_512_limbs(r: &mut Scalar, l: &[u64; 8]) {
    let n0 = l[4];
    let n1 = l[5];
    let n2 = l[6];
    let n3 = l[7];

    let mut c0: u64 = l[0];
    let mut c1: u64 = 0;
    let mut c2: u32 = 0;

    macro_rules! muladd_fast {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let (lo, o) = c0.overflowing_add(prod as u64);
            c0 = lo;
            c1 += (prod >> 64) as u64 + o as u64;
        }};
    }
    macro_rules! muladd {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let (lo, o1) = c0.overflowing_add(prod as u64);
            c0 = lo;
            let th = (prod >> 64) as u64 + o1 as u64;
            let (mid, o2) = c1.overflowing_add(th);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! sumadd_fast {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            c1 += o as u64;
        }};
    }
    macro_rules! sumadd {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            let (mid, o2) = c1.overflowing_add(o as u64);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! extract {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = c2 as u64;
            c2 = 0;
            n
        }};
    }
    macro_rules! extract_fast {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = 0;
            n
        }};
    }

    muladd_fast!(n0, N_C_0);
    let m0 = extract_fast!();
    sumadd_fast!(l[1]);
    muladd!(n1, N_C_0);
    muladd!(n0, N_C_1);
    let m1 = extract!();
    sumadd!(l[2]);
    muladd!(n2, N_C_0);
    muladd!(n1, N_C_1);
    sumadd!(n0);
    let m2 = extract!();
    sumadd!(l[3]);
    muladd!(n3, N_C_0);
    muladd!(n2, N_C_1);
    sumadd!(n1);
    let m3 = extract!();
    muladd!(n3, N_C_1);
    sumadd!(n2);
    let m4 = extract!();
    sumadd_fast!(n3);
    let m5 = extract_fast!();
    let m6 = c0 as u32;

    c0 = m0;
    c1 = 0;
    c2 = 0;
    muladd_fast!(m4, N_C_0);
    let p0 = extract_fast!();
    sumadd_fast!(m1);
    muladd!(m5, N_C_0);
    muladd!(m4, N_C_1);
    let p1 = extract!();
    sumadd!(m2);
    muladd!(m6 as u64, N_C_0);
    muladd!(m5, N_C_1);
    sumadd!(m4);
    let p2 = extract!();
    sumadd_fast!(m3);
    muladd_fast!(m6 as u64, N_C_1);
    sumadd_fast!(m5);
    let p3 = extract_fast!();
    let p4 = (c0 + m6 as u64) as u32;

    let mut t: u128 = p0 as u128;
    t += (N_C_0 as u128) * (p4 as u128);
    r.d[0] = t as u64;
    t >>= 64;
    t += p1 as u128;
    t += (N_C_1 as u128) * (p4 as u128);
    r.d[1] = t as u64;
    t >>= 64;
    t += p2 as u128;
    t += p4 as u128;
    r.d[2] = t as u64;
    t >>= 64;
    t += p3 as u128;
    r.d[3] = t as u64;
    let c = (t >> 64) as u64;

    scalar_reduce(r, c + scalar_check_overflow(r));
}

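// Free-function variant of `Scalar::reduce`: subtracts `overflow * n` by
// adding `overflow * (2^256 - n)` modulo 2^256.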
fn scalar_reduce(r: &mut Scalar, overflow: u64) {
    let of = overflow as u128;
    let mut t: u128 = r.d[0] as u128;
    t += of * (N_C_0 as u128);
    r.d[0] = t as u64;
    t >>= 64;
    t += r.d[1] as u128;
    t += of * (N_C_1 as u128);
    r.d[1] = t as u64;
    t >>= 64;
    t += r.d[2] as u128;
    t += of * (N_C_2 as u128);
    r.d[2] = t as u64;
    t >>= 64;
    r.d[3] = (t as u64).wrapping_add(r.d[3]);
}

fn scalar_check_overflow(r: &Scalar) -> u64 {
    let mut yes = 0u64;
    let mut no = 0u64;
    no |= (r.d[3] < N_3) as u64;
    no |= (r.d[2] < N_2) as u64;
    yes |= (r.d[2] > N_2) as u64 & !no;
    no |= (r.d[1] < N_1) as u64;
    yes |= (r.d[1] > N_1) as u64 & !no;
    yes |= (r.d[0] >= N_0) as u64 & !no;
    yes
}

fn scalar_reduce_512(r: &mut Scalar, l: &[u64; 8]) {
    #[cfg(target_arch = "x86_64")]
    {
        let c = unsafe { scalar_asm::scalar_reduce_512_asm(r, l.as_ptr()) };
        scalar_reduce(r, c + scalar_check_overflow(r));
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_reduce_512_limbs(r, l);
    }
}

#[cfg(test)]
#[test]
fn test_scalar_reduce_n_plus_1() {
    let l = [N_0 + 1, N_1, N_2, N_3, 0, 0, 0, 0];
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_one(), "(n+1) mod n = 1, got r.d = {:?}", r.d);
}

#[cfg(test)]
#[test]
fn test_scalar_mul_inv2_times_2() {
    let inv2_hex = "7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1";
    let inv2_bytes = hex::decode(inv2_hex).unwrap();
    let mut buf = [0u8; 32];
    buf.copy_from_slice(&inv2_bytes);
    let mut inv2 = Scalar::zero();
    inv2.set_b32(&buf);
    let mut two = Scalar::zero();
    two.set_int(2);
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, &inv2, &two);
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_one(), "inv2*2 mod n = 1");
}
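
#[cfg(test)]
#[test]
fn test_scalar_div2_then_double() {
    // Halving an odd scalar and adding the result to itself must round-trip:
    // div2 computes (a + n) / 2 for odd a, and doubling restores a mod n.
    let mut a = Scalar::zero();
    a.set_int(7);
    let orig = a;
    a.div2();
    let mut doubled = Scalar::zero();
    doubled.add(&a, &a);
    assert!(bool::from(doubled.ct_eq(&orig)), "2 * (a/2) mod n = a");
}

#[cfg(test)]
#[test]
fn test_scalar_negate_add_is_zero() {
    // a + (n - a) = n = 0 (mod n), so negate followed by add must yield zero.
    let mut a = Scalar::zero();
    a.set_int(12345);
    let mut neg = Scalar::zero();
    neg.negate(&a);
    let mut sum = Scalar::zero();
    sum.add(&a, &neg);
    assert!(sum.is_zero(), "a + (-a) mod n = 0");
}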
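// Computes r = round((a * b) / 2^shift) for shift >= 256; used by
// split_lambda with shift = 384. Rounding is done by conditionally adding
// the bit of the product just below the cut.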
fn scalar_mul_shift_var(r: &mut Scalar, a: &Scalar, b: &Scalar, shift: u32) {
    assert!(shift >= 256);
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, a, b);
    let shiftlimbs = (shift >> 6) as usize;
    let shiftlow = shift & 0x3F;
    let shifthigh = 64 - shiftlow;
    r.d[0] = if shift < 512 {
        (l[shiftlimbs] >> shiftlow)
            | (if shift < 448 && shiftlow != 0 {
                l[1 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[1] = if shift < 448 {
        (l[1 + shiftlimbs] >> shiftlow)
            | (if shift < 384 && shiftlow != 0 {
                l[2 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[2] = if shift < 384 {
        (l[2 + shiftlimbs] >> shiftlow)
            | (if shift < 320 && shiftlow != 0 {
                l[3 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[3] = if shift < 320 {
        l[3 + shiftlimbs] >> shiftlow
    } else {
        0
    };
    let bit = (l[(shift - 1) as usize >> 6] >> ((shift - 1) & 0x3F)) & 1;
    scalar_cadd_bit(r, 0, bit != 0);
}

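// Conditionally adds 2^bit to r when `flag` is true (`bit` must be below
// 256). No modular reduction is performed, so the caller must ensure the
// result stays below n.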
fn scalar_cadd_bit(r: &mut Scalar, bit: u32, flag: bool) {
    let bit = if flag { bit } else { bit + 256 };
    if bit >= 256 {
        return;
    }
    let mut t: u128 = r.d[0] as u128
        + if (bit >> 6) == 0 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[0] = t as u64;
    t >>= 64;
    t += r.d[1] as u128
        + if (bit >> 6) == 1 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[1] = t as u64;
    t >>= 64;
    t += r.d[2] as u128
        + if (bit >> 6) == 2 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[2] = t as u64;
    t >>= 64;
    t += r.d[3] as u128
        + if (bit >> 6) == 3 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[3] = t as u64;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_lambda_identity() {
        let mut k = Scalar::zero();
        k.set_int(42);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(bool::from(check.ct_eq(&k)), "r1 + lambda*r2 should equal k");
    }

    #[test]
    fn test_split_lambda_neg_three() {
        let mut three = Scalar::zero();
        three.set_int(3);
        let mut k = Scalar::zero();
        k.negate(&three);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(
            bool::from(check.ct_eq(&k)),
            "r1 + lambda*r2 should equal k for k=-3"
        );
    }

    #[test]
    fn test_split_lambda_ecdsa_scalar() {
        let mut k = Scalar::zero();
        k.d = [
            11125243483441707226,
            2149109665766520832,
            14302025600096445326,
            4162584031737161978,
        ];

        let n_big = scalar_to_biguint(&N);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let r1_big = scalar_to_biguint(&r1);
        let r2_big = scalar_to_biguint(&r2);
        let n_half = &n_big / BigUint::from(2u64);
        let r1_abs = if r1_big > n_half {
            &n_big - &r1_big
        } else {
            r1_big.clone()
        };
        let r2_abs = if r2_big > n_half {
            &n_big - &r2_big
        } else {
            r2_big.clone()
        };
        assert!(
            r1_abs.bits() <= 128,
            "|r1| exceeds 128 bits: {}",
            r1_abs.bits()
        );
        assert!(
            r2_abs.bits() <= 128,
            "|r2| exceeds 128 bits: {}",
            r2_abs.bits()
        );

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(bool::from(check.ct_eq(&k)), "r1 + lambda*r2 should equal k");
    }

    #[test]
    fn test_split_128_identity() {
        let mut k = Scalar::zero();
        k.set_int(0x1234_5678);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_128(&mut r1, &mut r2, &k);

        let mut two_128 = Scalar::zero();
        two_128.d[2] = 1;
        let mut r2_shifted = Scalar::zero();
        r2_shifted.mul(&r2, &two_128);
        let mut check = Scalar::zero();
        check.add(&r1, &r2_shifted);
        assert!(bool::from(check.ct_eq(&k)), "r1 + 2^128*r2 should equal k");
    }
}