#[cfg(target_arch = "x86_64")]
mod scalar_asm {
    use super::Scalar;

    extern "C" {
        fn blvm_secp256k1_scalar_mul_512(l8: *mut u64, a: *const Scalar, b: *const Scalar);

        fn blvm_secp256k1_scalar_reduce_512(r: *mut Scalar, l: *const u64) -> u64;
    }

    /// # Safety
    /// `l` must point to eight writable `u64` limbs; `a` and `b` must be valid scalars.
    #[inline(always)]
    pub(super) unsafe fn scalar_mul_512_asm(l: *mut u64, a: *const Scalar, b: *const Scalar) {
        blvm_secp256k1_scalar_mul_512(l, a, b);
    }

    /// # Safety
    /// `l` must point to eight readable `u64` limbs. The return value is a carry
    /// that the caller folds into a final reduction.
    #[inline(always)]
    pub(super) unsafe fn scalar_reduce_512_asm(r: *mut Scalar, l: *const u64) -> u64 {
        blvm_secp256k1_scalar_reduce_512(r, l)
    }
}

use num_bigint::BigUint;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

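/// A scalar modulo the secp256k1 group order `n`, stored as four 64-bit limbs,
/// least-significant limb first. `#[repr(C)]` keeps the layout compatible with
/// the external assembly routines above.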
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct Scalar {
    pub d: [u64; 4],
}

// Limbs of the secp256k1 group order
// n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141.
const N_0: u64 = 0xBFD25E8CD0364141;
const N_1: u64 = 0xBAAEDCE6AF48A03B;
const N_2: u64 = 0xFFFFFFFFFFFFFFFE;
const N_3: u64 = 0xFFFFFFFFFFFFFFFF;

// Limbs of 2^256 - n (the complement of the order), used during reduction.
const N_C_0: u64 = 0x402DA1732FC9BEBF;
const N_C_1: u64 = 0x4551231950B75FC4;
const N_C_2: u64 = 1;

// Limbs of (n - 1) / 2, the threshold used by `is_high`.
const N_H_0: u64 = 0xDFE92F46681B20A0;
const N_H_1: u64 = 0x5D576E7357A4501D;
const N_H_2: u64 = 0xFFFFFFFFFFFFFFFF;
const N_H_3: u64 = 0x7FFFFFFFFFFFFFFF;

#[allow(dead_code)]
const N: Scalar = Scalar {
    d: [N_0, N_1, N_2, N_3],
};

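/// The GLV endomorphism constant lambda: a nontrivial cube root of unity
/// modulo n (lambda^3 ≡ 1 mod n), used by `split_lambda`.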
const LAMBDA: Scalar = Scalar {
    d: [
        0xDF02967C1B23BD72,
        0x122E22EA20816678,
        0xA5261C028812645A,
        0x5363AD4CC05C30E0,
    ],
};

impl Scalar {
    pub fn zero() -> Self {
        Self { d: [0, 0, 0, 0] }
    }

    pub fn one() -> Self {
        Self { d: [1, 0, 0, 0] }
    }

    pub fn set_int(&mut self, v: u32) {
        self.d[0] = v as u64;
        self.d[1] = 0;
        self.d[2] = 0;
        self.d[3] = 0;
    }

    /// Loads a scalar from 32 big-endian bytes, reducing mod n.
    /// Returns `true` if the input was >= n (i.e. it overflowed).
    pub fn set_b32(&mut self, bin: &[u8; 32]) -> bool {
        self.d[0] = read_be64(&bin[24..32]);
        self.d[1] = read_be64(&bin[16..24]);
        self.d[2] = read_be64(&bin[8..16]);
        self.d[3] = read_be64(&bin[0..8]);
        let overflow = self.check_overflow();
        self.reduce(overflow as u64);
        overflow
    }

    /// Writes the scalar as 32 big-endian bytes.
    pub fn get_b32(&self, bin: &mut [u8; 32]) {
        write_be64(&mut bin[0..8], self.d[3]);
        write_be64(&mut bin[8..16], self.d[2]);
        write_be64(&mut bin[16..24], self.d[1]);
        write_be64(&mut bin[24..32], self.d[0]);
    }

    /// Branch-free check for `self >= n`, comparing limbs from most to least
    /// significant.
    fn check_overflow(&self) -> bool {
        let mut yes = 0u64;
        let mut no = 0u64;
        no |= (self.d[3] < N_3) as u64;
        no |= (self.d[2] < N_2) as u64;
        yes |= (self.d[2] > N_2) as u64 & !no;
        no |= (self.d[1] < N_1) as u64;
        yes |= (self.d[1] > N_1) as u64 & !no;
        yes |= (self.d[0] >= N_0) as u64 & !no;
        yes != 0
    }

    /// Subtracts n if `overflow` is 1, exploiting 2^256 ≡ N_C (mod n):
    /// subtracting n is the same as adding N_C and dropping the top carry.
    fn reduce(&mut self, overflow: u64) {
        let mut t: u128 = self.d[0] as u128 + (overflow as u128 * N_C_0 as u128);
        self.d[0] = t as u64;
        t >>= 64;
        t += self.d[1] as u128 + (overflow as u128 * N_C_1 as u128);
        self.d[1] = t as u64;
        t >>= 64;
        t += self.d[2] as u128 + (overflow as u128 * N_C_2 as u128);
        self.d[2] = t as u64;
        t >>= 64;
        t += self.d[3] as u128;
        self.d[3] = t as u64;
    }

    pub fn is_zero(&self) -> bool {
        (self.d[0] | self.d[1] | self.d[2] | self.d[3]) == 0
    }

    pub fn is_one(&self) -> bool {
        ((self.d[0] ^ 1) | self.d[1] | self.d[2] | self.d[3]) == 0
    }

    #[allow(dead_code)]
    fn is_odd(&self) -> bool {
        self.d[0] & 1 != 0
    }

    pub(crate) fn is_even(&self) -> bool {
        self.d[0] & 1 == 0
    }

    #[allow(dead_code)]
    fn sub(&mut self, a: &Scalar, b: &Scalar) {
        let mut neg_b = Scalar::zero();
        neg_b.negate(b);
        self.add(a, &neg_b);
    }

    /// Shifts right by one bit; a correct halving mod n only when self is even.
    #[allow(dead_code)]
    fn half(&mut self) {
        self.d[0] = (self.d[0] >> 1) | (self.d[1] << 63);
        self.d[1] = (self.d[1] >> 1) | (self.d[2] << 63);
        self.d[2] = (self.d[2] >> 1) | (self.d[3] << 63);
        self.d[3] >>= 1;
    }

    /// Computes (self + n) / 2; since n is odd, this halves an odd scalar mod n.
    #[allow(dead_code)]
    fn half_add_n(&mut self) {
        let mut t: u128 = self.d[0] as u128 + N_0 as u128;
        let c0 = t as u64;
        let mut c1 = (t >> 64) as u64;
        t = self.d[1] as u128 + N_1 as u128 + c1 as u128;
        c1 = t as u64;
        let mut c2 = (t >> 64) as u64;
        t = self.d[2] as u128 + N_2 as u128 + c2 as u128;
        c2 = t as u64;
        let mut c3 = (t >> 64) as u64;
        t = self.d[3] as u128 + N_3 as u128 + c3 as u128;
        c3 = t as u64;
        let c4 = (t >> 64) as u64;
        self.d[0] = (c0 >> 1) | (c1 << 63);
        self.d[1] = (c1 >> 1) | (c2 << 63);
        self.d[2] = (c2 >> 1) | (c3 << 63);
        self.d[3] = (c3 >> 1) | (c4 << 63);
        self.reduce(self.check_overflow() as u64);
    }

    /// Divides by two mod n (variable time in the low bit).
    #[allow(dead_code)]
    fn div2(&mut self) {
        if self.is_odd() {
            self.half_add_n();
        } else {
            self.half();
        }
    }

    /// Adds without reduction, returning the five-limb (up to 257-bit) sum.
    #[allow(dead_code)]
    fn add_no_reduce(a: &Scalar, b: &Scalar) -> [u64; 5] {
        let mut t: u128 = a.d[0] as u128 + b.d[0] as u128;
        let c0 = t as u64;
        let mut c1 = (t >> 64) as u64;
        t = a.d[1] as u128 + b.d[1] as u128 + c1 as u128;
        c1 = t as u64;
        let mut c2 = (t >> 64) as u64;
        t = a.d[2] as u128 + b.d[2] as u128 + c2 as u128;
        c2 = t as u64;
        let mut c3 = (t >> 64) as u64;
        t = a.d[3] as u128 + b.d[3] as u128 + c3 as u128;
        c3 = t as u64;
        let c4 = (t >> 64) as u64;
        [c0, c1, c2, c3, c4]
    }

    /// Sets self to half of a five-limb value, then reduces mod n.
    #[allow(dead_code)]
    fn set_from_5limb_half(&mut self, c: &[u64; 5]) {
        self.d[0] = (c[0] >> 1) | (c[1] << 63);
        self.d[1] = (c[1] >> 1) | (c[2] << 63);
        self.d[2] = (c[2] >> 1) | (c[3] << 63);
        self.d[3] = (c[3] >> 1) | (c[4] << 63);
        self.reduce(self.check_overflow() as u64);
    }

    #[allow(dead_code)]
    fn sub_half(&mut self, a: &Scalar, b: &Scalar) {
        self.sub(a, b);
        self.div2();
    }

    /// Sets self = a + b mod n; returns whether the raw sum was reduced.
    pub fn add(&mut self, a: &Scalar, b: &Scalar) -> bool {
        let mut t: u128 = a.d[0] as u128 + b.d[0] as u128;
        self.d[0] = t as u64;
        t >>= 64;
        t += a.d[1] as u128 + b.d[1] as u128;
        self.d[1] = t as u64;
        t >>= 64;
        t += a.d[2] as u128 + b.d[2] as u128;
        self.d[2] = t as u64;
        t >>= 64;
        t += a.d[3] as u128 + b.d[3] as u128;
        self.d[3] = t as u64;
        t >>= 64;
        let overflow = t as u64 + self.check_overflow() as u64;
        debug_assert!(overflow <= 1);
        self.reduce(overflow);
        overflow != 0
    }

    /// Sets self = -a mod n; the all-zero mask keeps -0 == 0.
    pub fn negate(&mut self, a: &Scalar) {
        let nonzero = if a.is_zero() { 0u64 } else { u64::MAX };
        let mut t: u128 = (!a.d[0]) as u128 + (N_0 + 1) as u128;
        self.d[0] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[1]) as u128 + N_1 as u128;
        self.d[1] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[2]) as u128 + N_2 as u128;
        self.d[2] = (t as u64) & nonzero;
        t >>= 64;
        t += (!a.d[3]) as u128 + N_3 as u128;
        self.d[3] = (t as u64) & nonzero;
    }

    pub fn mul(&mut self, a: &Scalar, b: &Scalar) {
        let mut l = [0u64; 8];
        scalar_mul_512(&mut l, a, b);
        scalar_reduce_512(self, &l);
    }

    /// GLV decomposition: splits k into r1, r2 with k = r1 + lambda*r2 (mod n),
    /// where r1 and r2 each fit in roughly 128 bits.
    pub fn split_lambda(r1: &mut Scalar, r2: &mut Scalar, k: &Scalar) {
        const MINUS_B1: Scalar = Scalar {
            d: [
                (0x6F547FA9u64 << 32) | 0x0ABFE4C3,
                (0xE4437ED6u64 << 32) | 0x010E8828,
                0,
                0,
            ],
        };
        const MINUS_B2: Scalar = Scalar {
            d: [
                (0xD765CDA8u64 << 32) | 0x3DB1562C,
                (0x8A280AC5u64 << 32) | 0x0774346D,
                (0xFFFFFFFFu64 << 32) | 0xFFFFFFFE,
                (0xFFFFFFFFu64 << 32) | 0xFFFFFFFF,
            ],
        };
        const G1: Scalar = Scalar {
            d: [
                (0xE893209Au64 << 32) | 0x45DBB031,
                (0x3DAA8A14u64 << 32) | 0x71E8CA7F,
                (0xE86C90E4u64 << 32) | 0x9284EB15,
                (0x3086D221u64 << 32) | 0xA7D46BCD,
            ],
        };
        const G2: Scalar = Scalar {
            d: [
                (0x1571B4AEu64 << 32) | 0x8AC47F71,
                (0x221208ACu64 << 32) | 0x9DF506C6,
                (0x6F547FA9u64 << 32) | 0x0ABFE4C4,
                (0xE4437ED6u64 << 32) | 0x010E8828,
            ],
        };

        // c1 = round(k*g1 / 2^384), c2 = round(k*g2 / 2^384).
        let mut c1 = Scalar::zero();
        let mut c2 = Scalar::zero();
        scalar_mul_shift_var(&mut c1, k, &G1, 384);
        scalar_mul_shift_var(&mut c2, k, &G2, 384);
        // r2 = c1*(-b1) + c2*(-b2); then r1 = k - lambda*r2.
        let mut t = Scalar::zero();
        t.mul(&c1, &MINUS_B1);
        c1 = t;
        t.mul(&c2, &MINUS_B2);
        c2 = t;
        r2.add(&c1, &c2);
        r1.mul(r2, &LAMBDA);
        let mut neg = Scalar::zero();
        neg.negate(r1);
        r1.add(&neg, k);
    }

    /// Extracts `count` bits starting at `offset`, which must lie within a
    /// single 64-bit limb.
    pub fn get_bits_limb32(&self, offset: u32, count: u32) -> u32 {
        debug_assert!(count > 0 && count <= 32);
        debug_assert!((offset + count - 1) >> 6 == offset >> 6);
        let limb = offset >> 6;
        let shift = offset & 0x3F;
        let mask = if count == 32 {
            u32::MAX
        } else {
            (1u32 << count) - 1
        };
        ((self.d[limb as usize] >> shift) as u32) & mask
    }

    /// Extracts `count` bits starting at `offset`, handling reads that straddle
    /// a limb boundary (variable time).
    pub fn get_bits_var(&self, offset: u32, count: u32) -> u32 {
        debug_assert!(count > 0 && count <= 32);
        debug_assert!(offset + count <= 256);
        if (offset + count - 1) >> 6 == offset >> 6 {
            self.get_bits_limb32(offset, count)
        } else {
            let limb = (offset >> 6) as usize;
            let shift = offset & 0x3F;
            let mask = if count == 32 {
                u32::MAX
            } else {
                (1u32 << count) - 1
            };
            // A straddling read implies shift != 0, so 64 - shift is in 1..=63.
            let lo = self.d[limb] >> shift;
            let hi = self.d[limb + 1].wrapping_shl(64u32 - shift);
            ((lo | hi) as u32) & mask
        }
    }

    /// Splits k into its low 128 bits (r1) and high 128 bits (r2),
    /// so that k = r1 + 2^128 * r2.
    pub fn split_128(r1: &mut Scalar, r2: &mut Scalar, k: &Scalar) {
        r1.d[0] = k.d[0];
        r1.d[1] = k.d[1];
        r1.d[2] = 0;
        r1.d[3] = 0;
        r2.d[0] = k.d[2];
        r2.d[1] = k.d[3];
        r2.d[2] = 0;
        r2.d[3] = 0;
    }

    /// Modular inverse, variable time. Uses the crate's modinv64 on 64-bit
    /// targets and a Fermat-inversion fallback (a^(n-2) mod n) elsewhere.
    /// The inverse of zero is defined as zero.
    pub fn inv_var(&mut self, a: &Scalar) {
        if a.is_zero() {
            *self = Scalar::zero();
            return;
        }
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            use crate::modinv64::{modinv64, SECP256K1_SCALAR_MODINV_MODINFO};
            let mut x = scalar_to_signed62(a);
            modinv64(&mut x, &SECP256K1_SCALAR_MODINV_MODINFO);
            scalar_from_signed62(self, &x);
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let a_big = scalar_to_biguint(a);
            let n_big = scalar_to_biguint(&N);
            let exp = &n_big - 2u32;
            let inv_big = a_big.modpow(&exp, &n_big);
            biguint_to_scalar(self, &inv_big);
        }
    }

    /// Returns true if self > (n - 1) / 2, i.e. the scalar lies in the high
    /// half of the range (branch-free comparison against N_H).
    pub fn is_high(&self) -> bool {
        let mut yes = 0u64;
        let mut no = 0u64;
        no |= (self.d[3] < N_H_3) as u64;
        yes |= (self.d[3] > N_H_3) as u64 & !no;
        no |= (self.d[2] < N_H_2) as u64 & !yes;
        no |= (self.d[1] < N_H_1) as u64 & !yes;
        yes |= (self.d[1] > N_H_1) as u64 & !no;
        yes |= (self.d[0] > N_H_0) as u64 & !no;
        yes != 0
    }

    /// Negates in place when `flag` is nonzero, using a mask instead of a
    /// branch on the scalar's value. Returns 1 if the value was negated and
    /// -1 if it was left unchanged.
    pub fn cond_negate(&mut self, flag: i32) -> i32 {
        let mask = if flag != 0 { u64::MAX } else { 0 };
        let nonzero = if self.is_zero() { 0 } else { u64::MAX };
        let mut t: u128 = (self.d[0] ^ mask) as u128;
        t += ((N_0 + 1) & mask) as u128;
        self.d[0] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[1] ^ mask) as u128;
        t += (N_1 & mask) as u128;
        self.d[1] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[2] ^ mask) as u128;
        t += (N_2 & mask) as u128;
        self.d[2] = (t as u64) & nonzero;
        t >>= 64;
        t += (self.d[3] ^ mask) as u128;
        t += (N_3 & mask) as u128;
        self.d[3] = (t as u64) & nonzero;
        if mask == 0 {
            -1
        } else {
            1
        }
    }
}

impl ConstantTimeEq for Scalar {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.d[0].ct_eq(&other.d[0])
            & self.d[1].ct_eq(&other.d[1])
            & self.d[2].ct_eq(&other.d[2])
            & self.d[3].ct_eq(&other.d[3])
    }
}

impl ConditionallySelectable for Scalar {
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        Self {
            d: [
                u64::conditional_select(&a.d[0], &b.d[0], choice),
                u64::conditional_select(&a.d[1], &b.d[1], choice),
                u64::conditional_select(&a.d[2], &b.d[2], choice),
                u64::conditional_select(&a.d[3], &b.d[3], choice),
            ],
        }
    }
}

/// Converts a 4x64 scalar to the 5x62 signed-limb form used by modinv64.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn scalar_to_signed62(a: &Scalar) -> crate::modinv64::Signed62 {
    const M62: u64 = u64::MAX >> 2;
    let d = &a.d;
    crate::modinv64::Signed62 {
        v: [
            (d[0] & M62) as i64,
            ((d[0] >> 62 | d[1] << 2) & M62) as i64,
            ((d[1] >> 60 | d[2] << 4) & M62) as i64,
            ((d[2] >> 58 | d[3] << 6) & M62) as i64,
            (d[3] >> 56) as i64,
        ],
    }
}

/// Packs a 5x62 signed-limb value (assumed fully reduced) back into 4x64 limbs.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn scalar_from_signed62(r: &mut Scalar, a: &crate::modinv64::Signed62) {
    let v = &a.v;
    r.d[0] = (v[0] as u64) | ((v[1] as u64) << 62);
    r.d[1] = ((v[1] as u64) >> 2) | ((v[2] as u64) << 60);
    r.d[2] = ((v[2] as u64) >> 4) | ((v[3] as u64) << 58);
    r.d[3] = ((v[3] as u64) >> 6) | ((v[4] as u64) << 56);
}

#[allow(dead_code)]
fn scalar_to_biguint(s: &Scalar) -> BigUint {
    let mut bytes = [0u8; 32];
    s.get_b32(&mut bytes);
    BigUint::from_bytes_be(&bytes)
}

/// Writes `b` (assumed to be < 2^256) into a scalar, reducing mod n.
#[allow(dead_code)]
fn biguint_to_scalar(r: &mut Scalar, b: &BigUint) {
    let bytes = b.to_bytes_be();
    let mut buf = [0u8; 32];
    let len = bytes.len().min(32);
    let start = 32 - len;
    buf[start..].copy_from_slice(&bytes[..len]);
    r.set_b32(&buf);
}

fn read_be64(b: &[u8]) -> u64 {
    u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])
}

fn write_be64(b: &mut [u8], v: u64) {
    b[0..8].copy_from_slice(&v.to_be_bytes());
}

/// 256x256 -> 512-bit multiply: fills `l` with the eight-limb product a*b.
fn scalar_mul_512(l: &mut [u64; 8], a: &Scalar, b: &Scalar) {
    #[cfg(target_arch = "x86_64")]
    {
        unsafe {
            scalar_asm::scalar_mul_512_asm(l.as_mut_ptr(), a, b);
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_mul_512_rust(l, a, b);
    }
}

/// Portable 4x64 schoolbook multiply, tracking a three-limb accumulator
/// (c0, c1, c2) in the style of libsecp256k1's muladd/extract macros.
#[cfg(not(target_arch = "x86_64"))]
fn scalar_mul_512_rust(l: &mut [u64; 8], a: &Scalar, b: &Scalar) {
    let mut c0: u64 = 0;
    let mut c1: u64 = 0;
    let mut c2: u32 = 0;

    // Add a*b to the accumulator, assuming the carry limb c1 cannot overflow.
    macro_rules! muladd_fast {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let prod_lo = prod as u64;
            let prod_hi = (prod >> 64) as u64;
            let (lo, o) = c0.overflowing_add(prod_lo);
            c0 = lo;
            c1 += prod_hi + o as u64;
        }};
    }
    // Add a*b to the accumulator, propagating carries into c2.
    macro_rules! muladd {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let hi = (prod >> 64) as u64;
            let (lo, o1) = c0.overflowing_add(prod as u64);
            c0 = lo;
            let th = hi + o1 as u64;
            let (mid, o2) = c1.overflowing_add(th);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    // Add a single u64 to the accumulator, propagating carries into c2.
    #[allow(unused_macros)]
    macro_rules! sumadd {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            let (mid, o2) = c1.overflowing_add(o as u64);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    // Pop the low limb of the accumulator and shift it down one position.
    macro_rules! extract {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = c2 as u64;
            c2 = 0;
            n
        }};
    }
    // Like extract!, but assumes c2 is already zero.
    macro_rules! extract_fast {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = 0;
            n
        }};
    }

    muladd_fast!(a.d[0], b.d[0]);
    l[0] = extract_fast!();
    muladd!(a.d[0], b.d[1]);
    muladd!(a.d[1], b.d[0]);
    l[1] = extract!();
    muladd!(a.d[0], b.d[2]);
    muladd!(a.d[1], b.d[1]);
    muladd!(a.d[2], b.d[0]);
    l[2] = extract!();
    muladd!(a.d[0], b.d[3]);
    muladd!(a.d[1], b.d[2]);
    muladd!(a.d[2], b.d[1]);
    muladd!(a.d[3], b.d[0]);
    l[3] = extract!();
    muladd!(a.d[1], b.d[3]);
    muladd!(a.d[2], b.d[2]);
    muladd!(a.d[3], b.d[1]);
    l[4] = extract!();
    muladd!(a.d[2], b.d[3]);
    muladd!(a.d[3], b.d[2]);
    l[5] = extract!();
    muladd_fast!(a.d[3], b.d[3]);
    l[6] = extract_fast!();
    l[7] = c0;
}

#[allow(dead_code)]
fn limbs_512_to_biguint(l: &[u64; 8]) -> BigUint {
    let mut acc = BigUint::from(0u64);
    for (i, &limb) in l.iter().enumerate() {
        acc += BigUint::from(limb) << (64 * i);
    }
    acc
}

/// Portable reduction of a 512-bit product mod n, in three passes that each
/// fold the limbs above 2^256 back in via 2^256 ≡ N_C (mod n).
#[cfg(not(target_arch = "x86_64"))]
fn scalar_reduce_512_limbs(r: &mut Scalar, l: &[u64; 8]) {
    let n0 = l[4];
    let n1 = l[5];
    let n2 = l[6];
    let n3 = l[7];

    let mut c0: u64 = l[0];
    let mut c1: u64 = 0;
    let mut c2: u32 = 0;

    macro_rules! muladd_fast {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let (lo, o) = c0.overflowing_add(prod as u64);
            c0 = lo;
            c1 += (prod >> 64) as u64 + o as u64;
        }};
    }
    macro_rules! muladd {
        ($a:expr, $b:expr) => {{
            let prod = ($a as u128) * ($b as u128);
            let (lo, o1) = c0.overflowing_add(prod as u64);
            c0 = lo;
            let th = (prod >> 64) as u64 + o1 as u64;
            let (mid, o2) = c1.overflowing_add(th);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! sumadd_fast {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            c1 += o as u64;
        }};
    }
    macro_rules! sumadd {
        ($a:expr) => {{
            let (lo, o) = c0.overflowing_add($a);
            c0 = lo;
            let (mid, o2) = c1.overflowing_add(o as u64);
            c1 = mid;
            c2 += o2 as u32;
        }};
    }
    macro_rules! extract {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = c2 as u64;
            c2 = 0;
            n
        }};
    }
    macro_rules! extract_fast {
        () => {{
            let n = c0;
            c0 = c1;
            c1 = 0;
            n
        }};
    }

    // Pass 1: reduce 512 bits into 385. m0..m6 = l[0..4] + (n0..n3) * N_C.
    muladd_fast!(n0, N_C_0);
    let m0 = extract_fast!();
    sumadd_fast!(l[1]);
    muladd!(n1, N_C_0);
    muladd!(n0, N_C_1);
    let m1 = extract!();
    sumadd!(l[2]);
    muladd!(n2, N_C_0);
    muladd!(n1, N_C_1);
    sumadd!(n0);
    let m2 = extract!();
    sumadd!(l[3]);
    muladd!(n3, N_C_0);
    muladd!(n2, N_C_1);
    sumadd!(n1);
    let m3 = extract!();
    muladd!(n3, N_C_1);
    sumadd!(n2);
    let m4 = extract!();
    sumadd_fast!(n3);
    let m5 = extract_fast!();
    let m6 = c0 as u32;

    // Pass 2: reduce 385 bits into 258. p0..p4 = m0..m3 + (m4..m6) * N_C.
    c0 = m0;
    c1 = 0;
    c2 = 0;
    muladd_fast!(m4, N_C_0);
    let p0 = extract_fast!();
    sumadd_fast!(m1);
    muladd!(m5, N_C_0);
    muladd!(m4, N_C_1);
    let p1 = extract!();
    sumadd!(m2);
    muladd!(m6 as u64, N_C_0);
    muladd!(m5, N_C_1);
    sumadd!(m4);
    let p2 = extract!();
    sumadd_fast!(m3);
    muladd_fast!(m6 as u64, N_C_1);
    sumadd_fast!(m5);
    let p3 = extract_fast!();
    let p4 = (c0 + m6 as u64) as u32;

    // Pass 3: reduce 258 bits into 256. r = p0..p3 + p4 * N_C.
    let mut t: u128 = p0 as u128;
    t += (N_C_0 as u128) * (p4 as u128);
    r.d[0] = t as u64;
    t >>= 64;
    t += p1 as u128;
    t += (N_C_1 as u128) * (p4 as u128);
    r.d[1] = t as u64;
    t >>= 64;
    t += p2 as u128;
    t += p4 as u128;
    r.d[2] = t as u64;
    t >>= 64;
    t += p3 as u128;
    r.d[3] = t as u64;
    let c = (t >> 64) as u64;

    // Final conditional subtraction of n.
    scalar_reduce(r, c + scalar_check_overflow(r));
}

/// Free-function variant of `Scalar::reduce`: adds `overflow * N_C`, which is
/// equivalent to subtracting `overflow * n` mod 2^256.
fn scalar_reduce(r: &mut Scalar, overflow: u64) {
    let of = overflow as u128;
    let mut t: u128 = r.d[0] as u128;
    t += of * (N_C_0 as u128);
    r.d[0] = t as u64;
    t >>= 64;
    t += r.d[1] as u128;
    t += of * (N_C_1 as u128);
    r.d[1] = t as u64;
    t >>= 64;
    t += r.d[2] as u128;
    t += of * (N_C_2 as u128);
    r.d[2] = t as u64;
    t >>= 64;
    r.d[3] = (t as u64).wrapping_add(r.d[3]);
}

/// Free-function variant of `Scalar::check_overflow`; returns 1 if r >= n.
fn scalar_check_overflow(r: &Scalar) -> u64 {
    let mut yes = 0u64;
    let mut no = 0u64;
    no |= (r.d[3] < N_3) as u64;
    no |= (r.d[2] < N_2) as u64;
    yes |= (r.d[2] > N_2) as u64 & !no;
    no |= (r.d[1] < N_1) as u64;
    yes |= (r.d[1] > N_1) as u64 & !no;
    yes |= (r.d[0] >= N_0) as u64 & !no;
    yes
}

/// Reduces an eight-limb 512-bit value into `r` mod n.
fn scalar_reduce_512(r: &mut Scalar, l: &[u64; 8]) {
    #[cfg(target_arch = "x86_64")]
    {
        let c = unsafe { scalar_asm::scalar_reduce_512_asm(r, l.as_ptr()) };
        scalar_reduce(r, c + scalar_check_overflow(r));
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_reduce_512_limbs(r, l);
    }
}

#[cfg(test)]
#[test]
fn test_scalar_reduce_n_plus_1() {
    let l = [N_0 + 1, N_1, N_2, N_3, 0, 0, 0, 0];
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_one(), "(n+1) mod n = 1, got r.d = {:?}", r.d);
}
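
// Additional sanity sketch in the same spirit as the test above: a + (-a) must
// vanish mod n, and negating zero must stay zero.
#[cfg(test)]
#[test]
fn test_scalar_negate_roundtrip() {
    let mut a = Scalar::zero();
    a.set_int(12345);
    let mut neg = Scalar::zero();
    neg.negate(&a);
    let mut sum = Scalar::zero();
    sum.add(&a, &neg);
    assert!(sum.is_zero(), "a + (-a) should be 0 mod n");
    let mut zneg = Scalar::zero();
    zneg.negate(&Scalar::zero());
    assert!(zneg.is_zero(), "-0 should be 0");
}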

#[cfg(test)]
#[test]
fn test_scalar_mul_inv2_times_2() {
    // (n + 1) / 2 is the multiplicative inverse of 2 mod n.
    let inv2_hex = "7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1";
    let inv2_bytes = hex::decode(inv2_hex).unwrap();
    let mut buf = [0u8; 32];
    buf.copy_from_slice(&inv2_bytes);
    let mut inv2 = Scalar::zero();
    inv2.set_b32(&buf);
    let mut two = Scalar::zero();
    two.set_int(2);
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, &inv2, &two);
    let mut r = Scalar::zero();
    scalar_reduce_512(&mut r, &l);
    assert!(r.is_one(), "inv2 * 2 mod n should be 1");
}
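
// Sanity sketch: inversion must round-trip through multiplication on both the
// modinv64 path and the BigUint fallback path.
#[cfg(test)]
#[test]
fn test_scalar_inv_var_roundtrip() {
    let mut a = Scalar::zero();
    a.set_int(7);
    let mut inv = Scalar::zero();
    inv.inv_var(&a);
    let mut prod = Scalar::zero();
    prod.mul(&a, &inv);
    assert!(prod.is_one(), "a * a^-1 should be 1 mod n");
}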

/// Computes r = round(a * b / 2^shift) for shift >= 256 (variable time).
/// The final `scalar_cadd_bit` adds bit (shift - 1) of the product, which
/// rounds to nearest instead of truncating.
fn scalar_mul_shift_var(r: &mut Scalar, a: &Scalar, b: &Scalar, shift: u32) {
    assert!(shift >= 256);
    let mut l = [0u64; 8];
    scalar_mul_512(&mut l, a, b);
    let shiftlimbs = (shift >> 6) as usize;
    let shiftlow = shift & 0x3F;
    let shifthigh = 64 - shiftlow;
    r.d[0] = if shift < 512 {
        (l[shiftlimbs] >> shiftlow)
            | (if shift < 448 && shiftlow != 0 {
                l[1 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[1] = if shift < 448 {
        (l[1 + shiftlimbs] >> shiftlow)
            | (if shift < 384 && shiftlow != 0 {
                l[2 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[2] = if shift < 384 {
        (l[2 + shiftlimbs] >> shiftlow)
            | (if shift < 320 && shiftlow != 0 {
                l[3 + shiftlimbs] << shifthigh
            } else {
                0
            })
    } else {
        0
    };
    r.d[3] = if shift < 320 {
        l[3 + shiftlimbs] >> shiftlow
    } else {
        0
    };
    let bit = (l[(shift - 1) as usize >> 6] >> ((shift - 1) & 0x3F)) & 1;
    scalar_cadd_bit(r, 0, bit != 0);
}
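
// Cross-check sketch: for shift = 384 the result should equal
// floor((a*b + 2^383) / 2^384), i.e. round-to-nearest with ties rounding up.
#[cfg(test)]
#[test]
fn test_scalar_mul_shift_var_reference() {
    let mut one = Scalar::zero();
    one.set_int(1);
    let mut a = Scalar::zero();
    a.negate(&one); // n - 1, a large operand
    let mut r = Scalar::zero();
    scalar_mul_shift_var(&mut r, &a, &a, 384);
    let prod = scalar_to_biguint(&a) * scalar_to_biguint(&a);
    let expected = (prod + (BigUint::from(1u8) << 383u32)) >> 384u32;
    assert_eq!(scalar_to_biguint(&r), expected);
}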

/// Conditionally adds 2^bit to r (no reduction; the caller must ensure the
/// result stays below n). When `flag` is false this is a no-op.
fn scalar_cadd_bit(r: &mut Scalar, bit: u32, flag: bool) {
    // When the flag is clear, push `bit` out of range and bail out.
    let bit = if flag { bit } else { bit + 256 };
    if bit >= 256 {
        return;
    }
    let mut t: u128 = r.d[0] as u128
        + if (bit >> 6) == 0 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[0] = t as u64;
    t >>= 64;
    t += r.d[1] as u128
        + if (bit >> 6) == 1 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[1] = t as u64;
    t >>= 64;
    t += r.d[2] as u128
        + if (bit >> 6) == 2 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[2] = t as u64;
    t >>= 64;
    t += r.d[3] as u128
        + if (bit >> 6) == 3 {
            1u128 << (bit & 0x3F)
        } else {
            0
        };
    r.d[3] = t as u64;
}
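
// Sketch test: scalar_cadd_bit adds 2^bit only when the flag is set.
#[cfg(test)]
#[test]
fn test_scalar_cadd_bit() {
    let mut r = Scalar::zero();
    scalar_cadd_bit(&mut r, 64, true);
    assert_eq!(r.d, [0, 1, 0, 0]);
    scalar_cadd_bit(&mut r, 0, false);
    assert_eq!(r.d, [0, 1, 0, 0], "flag = false must be a no-op");
}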

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_lambda_identity() {
        let mut k = Scalar::zero();
        k.set_int(42);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(bool::from(check.ct_eq(&k)), "r1 + lambda*r2 should equal k");
    }

    #[test]
    fn test_split_lambda_neg_three() {
        let mut three = Scalar::zero();
        three.set_int(3);
        let mut k = Scalar::zero();
        k.negate(&three);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(
            bool::from(check.ct_eq(&k)),
            "r1 + lambda*r2 should equal k for k = -3"
        );
    }

    #[test]
    fn test_split_lambda_ecdsa_scalar() {
        let mut k = Scalar::zero();
        k.d = [
            11125243483441707226,
            2149109665766520832,
            14302025600096445326,
            4162584031737161978,
        ];

        let n_big = scalar_to_biguint(&N);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_lambda(&mut r1, &mut r2, &k);

        // Both components must be short: |r1|, |r2| <= 128 bits, where the
        // absolute value folds high scalars to their negation mod n.
        let r1_big = scalar_to_biguint(&r1);
        let r2_big = scalar_to_biguint(&r2);
        let n_half = &n_big / BigUint::from(2u64);
        let r1_abs = if r1_big > n_half {
            &n_big - &r1_big
        } else {
            r1_big.clone()
        };
        let r2_abs = if r2_big > n_half {
            &n_big - &r2_big
        } else {
            r2_big.clone()
        };
        assert!(
            r1_abs.bits() <= 128,
            "|r1| exceeds 128 bits: {}",
            r1_abs.bits()
        );
        assert!(
            r2_abs.bits() <= 128,
            "|r2| exceeds 128 bits: {}",
            r2_abs.bits()
        );

        let mut lambda_r2 = Scalar::zero();
        lambda_r2.mul(&r2, &LAMBDA);
        let mut check = Scalar::zero();
        check.add(&r1, &lambda_r2);
        assert!(bool::from(check.ct_eq(&k)), "r1 + lambda*r2 should equal k");
    }

    #[test]
    fn test_split_128_identity() {
        let mut k = Scalar::zero();
        k.set_int(0x1234_5678);

        let mut r1 = Scalar::zero();
        let mut r2 = Scalar::zero();
        Scalar::split_128(&mut r1, &mut r2, &k);

        // Recombine: k should equal r1 + 2^128 * r2.
        let mut two_128 = Scalar::zero();
        two_128.d[2] = 1;
        let mut r2_shifted = Scalar::zero();
        r2_shifted.mul(&r2, &two_128);
        let mut check = Scalar::zero();
        check.add(&r1, &r2_shifted);
        assert!(bool::from(check.ct_eq(&k)), "r1 + 2^128*r2 should equal k");
    }
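
    // Additional sanity sketches for the comparison, bit-access, and
    // conditional-negation helpers.

    #[test]
    fn test_get_bits_across_limb_boundary() {
        // Bits 60..63 come from d[0] (all ones here), bits 64..67 from d[1].
        let mut k = Scalar::zero();
        k.d = [u64::MAX, 1, 0, 0];
        assert_eq!(k.get_bits_var(60, 8), 0x1F);
        // Within a single limb, both accessors must agree.
        assert_eq!(k.get_bits_var(0, 32), k.get_bits_limb32(0, 32));
    }

    #[test]
    fn test_is_high_boundary() {
        // (n - 1) / 2 is the largest "low" scalar; one above it is "high".
        let mut half = Scalar::zero();
        half.d = [N_H_0, N_H_1, N_H_2, N_H_3];
        assert!(!half.is_high());
        let mut one = Scalar::zero();
        one.set_int(1);
        let mut above = Scalar::zero();
        above.add(&half, &one);
        assert!(above.is_high());
    }

    #[test]
    fn test_cond_negate_convention() {
        // Exercises the sign convention as implemented here: flag != 0 negates
        // and returns 1; flag == 0 leaves the value and returns -1.
        let mut a = Scalar::zero();
        a.set_int(99);
        let mut expected = Scalar::zero();
        expected.negate(&a);
        let mut b = a;
        assert_eq!(b.cond_negate(1), 1);
        assert!(bool::from(b.ct_eq(&expected)));
        let mut c = a;
        assert_eq!(c.cond_negate(0), -1);
        assert!(bool::from(c.ct_eq(&a)));
    }

    #[test]
    fn test_mul_n_minus_one_squared() {
        // (n-1)^2 ≡ 1 (mod n) exercises the full multiply-and-reduce path.
        let mut one = Scalar::zero();
        one.set_int(1);
        let mut a = Scalar::zero();
        a.negate(&one);
        let mut r = Scalar::zero();
        r.mul(&a, &a);
        assert!(r.is_one());
    }

    #[test]
    fn test_set_b32_overflow_reduces() {
        // Feeding the group order n itself must report overflow and reduce to 0.
        let mut bytes = [0u8; 32];
        N.get_b32(&mut bytes);
        let mut s = Scalar::zero();
        assert!(s.set_b32(&bytes));
        assert!(s.is_zero());
    }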
}