1use sha2::{Digest, Sha256};
10use subtle::Choice;
11
12use crate::ecmult;
13use crate::ecmult_gen_const;
14use crate::field::FieldElement;
15use crate::group::{Ge, Gej};
16use crate::scalar::Scalar;
17
18pub(crate) fn tagged_hash(tag: &[u8], data: &[u8]) -> [u8; 32] {
21 let midstate = tagged_hash_midstate(tag);
22 tagged_hash_from_midstate(&midstate, data)
23}
24
25fn tagged_hash_midstate(tag: &[u8]) -> Sha256 {
28 let tag_hash = Sha256::digest(tag);
29 let mut hasher = Sha256::new();
30 hasher.update(tag_hash);
31 hasher.update(tag_hash);
32 hasher
33}
34
35#[inline(always)]
36fn tagged_hash_from_midstate(midstate: &Sha256, data: &[u8]) -> [u8; 32] {
37 let mut h = midstate.clone();
38 h.update(data);
39 h.finalize().into()
40}
41
42use std::sync::OnceLock;
43
44fn challenge_midstate() -> &'static Sha256 {
45 static MID: OnceLock<Sha256> = OnceLock::new();
46 MID.get_or_init(|| tagged_hash_midstate(b"BIP0340/challenge"))
47}
48
49fn nonce_midstate() -> &'static Sha256 {
50 static MID: OnceLock<Sha256> = OnceLock::new();
51 MID.get_or_init(|| tagged_hash_midstate(b"BIP0340/nonce"))
52}
53
54fn aux_midstate() -> &'static Sha256 {
55 static MID: OnceLock<Sha256> = OnceLock::new();
56 MID.get_or_init(|| tagged_hash_midstate(b"BIP0340/aux"))
57}
58
59#[inline(always)]
61pub(crate) fn tagged_hash_challenge(data: &[u8]) -> [u8; 32] {
62 tagged_hash_from_midstate(challenge_midstate(), data)
63}
64
65#[inline(always)]
68pub fn lift_x(x_bytes: &[u8; 32]) -> Option<Ge> {
69 let mut x = FieldElement::zero();
70 if !x.set_b32_limit(x_bytes) {
71 return None;
72 }
73 let mut p = Ge::default();
74 if !p.set_xo_var(&x, false) {
75 return None;
76 }
77 Some(p)
78}
79
80fn ge_to_xonly(p: &Ge) -> [u8; 32] {
83 let mut x = p.x;
84 x.normalize();
85 let mut out = [0u8; 32];
86 x.get_b32(&mut out);
87 out
88}
89
90pub fn xonly_pubkey_from_secret(seckey: &[u8; 32]) -> Option<[u8; 32]> {
93 let mut d = Scalar::zero();
94 if d.set_b32(seckey) {
95 return None;
96 }
97 if d.is_zero() {
98 return None;
99 }
100 let mut pj = Gej::default();
101 ecmult_gen_const(&mut pj, &d);
102 let mut p = Ge::default();
103 p.set_gej_var(&pj);
104 p.x.normalize();
105 let mut out = [0u8; 32];
107 p.x.get_b32(&mut out);
108 Some(out)
109}
110
111#[inline(always)]
116pub fn schnorr_verify(sig64: &[u8; 64], msg: &[u8], pubkey_x32: &[u8; 32]) -> bool {
117 let r_bytes: [u8; 32] = sig64[0..32].try_into().unwrap();
118 let s_bytes: [u8; 32] = sig64[32..64].try_into().unwrap();
119
120 let mut r_fe = FieldElement::zero();
121 if !r_fe.set_b32_limit(&r_bytes) {
122 return false; }
124
125 let mut s = Scalar::zero();
126 if s.set_b32(&s_bytes) {
127 return false; }
129
130 let p = match lift_x(pubkey_x32) {
131 Some(p) => p,
132 None => return false,
133 };
134
135 let e_hash = {
136 let mut h = challenge_midstate().clone();
137 h.update(r_bytes);
138 h.update(pubkey_x32);
139 h.update(msg);
140 let result: [u8; 32] = h.finalize().into();
141 result
142 };
143
144 let mut e = Scalar::zero();
145 e.set_b32(&e_hash);
146
147 let mut neg_e = Scalar::zero();
148 neg_e.negate(&e);
149
150 let mut pj = Gej::default();
151 pj.set_ge(&p);
152
153 let mut rj = Gej::default();
155 ecmult::ecmult(&mut rj, &pj, &neg_e, Some(&s));
156
157 if rj.infinity {
158 return false;
159 }
160
161 let mut r_computed = Ge::default();
162 r_computed.set_gej_var(&rj);
163 r_computed.x.normalize();
164 r_computed.y.normalize();
165
166 if r_computed.y.is_odd() {
167 return false;
168 }
169
170 let mut r_computed_bytes = [0u8; 32];
171 r_computed.x.get_b32(&mut r_computed_bytes);
172
173 r_computed_bytes == r_bytes
174}
175
176pub fn schnorr_verify_batch(sigs: &[[u8; 64]], msgs: &[&[u8]], pubkeys: &[[u8; 32]]) -> bool {
180 let n = sigs.len().min(msgs.len()).min(pubkeys.len());
181 if n == 0 {
182 return true;
183 }
184 if n == 1 {
185 return schnorr_verify(&sigs[0], msgs[0], &pubkeys[0]);
186 }
187
188 match schnorr_verify_batch_inner(sigs, msgs, pubkeys, n) {
190 Some(ok) => ok,
191 None => (0..n).all(|i| schnorr_verify(&sigs[i], msgs[i], &pubkeys[i])),
192 }
193}
194
195fn schnorr_verify_batch_inner(
198 sigs: &[[u8; 64]],
199 msgs: &[&[u8]],
200 pubkeys: &[[u8; 32]],
201 n: usize,
202) -> Option<bool> {
203 if n <= 16 {
204 schnorr_verify_batch_stack_16(sigs, msgs, pubkeys, n)
205 } else if n <= 64 {
206 schnorr_verify_batch_stack_64(sigs, msgs, pubkeys, n)
207 } else {
208 schnorr_verify_batch_heap(sigs, msgs, pubkeys, n)
209 }
210}
211
212struct BatchParseState<'a> {
213 s_scalars: &'a mut [Scalar],
214 e_scalars: &'a mut [Scalar],
215 pk_points: &'a mut [Ge],
216 r_points: &'a mut [Ge],
217 randoms: &'a mut [Scalar],
218}
219
220fn schnorr_verify_batch_parse(
221 sigs: &[[u8; 64]],
222 msgs: &[&[u8]],
223 pubkeys: &[[u8; 32]],
224 n: usize,
225 state: &mut BatchParseState<'_>,
226) -> Option<()> {
227 for i in 0..n {
228 let r_bytes: [u8; 32] = sigs[i][0..32].try_into().unwrap();
229 let s_bytes: [u8; 32] = sigs[i][32..64].try_into().unwrap();
230
231 let mut r_fe = FieldElement::zero();
232 if !r_fe.set_b32_limit(&r_bytes) {
233 return None;
234 }
235
236 let mut s = Scalar::zero();
237 if s.set_b32(&s_bytes) {
238 return None;
239 }
240 state.s_scalars[i] = s;
241
242 state.pk_points[i] = lift_x(&pubkeys[i])?;
243
244 let e_hash = {
245 let mut h = challenge_midstate().clone();
246 h.update(r_bytes);
247 h.update(pubkeys[i]);
248 h.update(msgs[i]);
249 let result: [u8; 32] = h.finalize().into();
250 result
251 };
252
253 let mut e = Scalar::zero();
254 e.set_b32(&e_hash);
255 state.e_scalars[i] = e;
256
257 state.r_points[i] = lift_x(&r_bytes)?;
258
259 let mut random = Scalar::zero();
260 let _ = random.set_b32(&r_bytes);
261 let mut random_tmp = Scalar::zero();
262 random_tmp.add(&random, &e);
263 state.randoms[i] = random_tmp;
264 }
265 Some(())
266}
267
268fn schnorr_verify_batch_stack_16(
269 sigs: &[[u8; 64]],
270 msgs: &[&[u8]],
271 pubkeys: &[[u8; 32]],
272 n: usize,
273) -> Option<bool> {
274 let mut s_scalars = [Scalar::zero(); 16];
275 let mut e_scalars = [Scalar::zero(); 16];
276 let mut pk_points = [Ge::default(); 16];
277 let mut r_points = [Ge::default(); 16];
278 let mut randoms = [Scalar::zero(); 16];
279 schnorr_verify_batch_parse(
280 sigs,
281 msgs,
282 pubkeys,
283 n,
284 &mut BatchParseState {
285 s_scalars: &mut s_scalars,
286 e_scalars: &mut e_scalars,
287 pk_points: &mut pk_points,
288 r_points: &mut r_points,
289 randoms: &mut randoms,
290 },
291 )?;
292
293 let mut g_scalar = Scalar::zero();
294 for i in 0..n {
295 let mut term = Scalar::zero();
296 term.mul(&randoms[i], &s_scalars[i]);
297 let mut g_new = Scalar::zero();
298 g_new.add(&g_scalar, &term);
299 g_scalar = g_new;
300 }
301
302 let mut all_scalars = [Scalar::zero(); 32];
303 let mut all_points = [Ge::default(); 32];
304 for i in 0..n {
305 let mut neg_e = Scalar::zero();
306 neg_e.negate(&e_scalars[i]);
307 let mut z_neg_e = Scalar::zero();
308 z_neg_e.mul(&randoms[i], &neg_e);
309 all_scalars[i] = z_neg_e;
310 all_points[i] = pk_points[i];
311 }
312 for i in 0..n {
313 let mut neg_rand = Scalar::zero();
314 neg_rand.negate(&randoms[i]);
315 all_scalars[n + i] = neg_rand;
316 all_points[n + i] = r_points[i];
317 }
318
319 let mut result = Gej::default();
320 ecmult::ecmult_multi(
321 &mut result,
322 &g_scalar,
323 &all_scalars[..2 * n],
324 &all_points[..2 * n],
325 );
326 Some(result.is_infinity())
327}
328
329fn schnorr_verify_batch_stack_64(
330 sigs: &[[u8; 64]],
331 msgs: &[&[u8]],
332 pubkeys: &[[u8; 32]],
333 n: usize,
334) -> Option<bool> {
335 let mut s_scalars = Vec::with_capacity(n);
336 let mut e_scalars = Vec::with_capacity(n);
337 let mut pk_points = Vec::with_capacity(n);
338 let mut r_points = Vec::with_capacity(n);
339 let mut randoms = Vec::with_capacity(n);
340 s_scalars.resize(n, Scalar::zero());
341 e_scalars.resize(n, Scalar::zero());
342 pk_points.resize(n, Ge::default());
343 r_points.resize(n, Ge::default());
344 randoms.resize(n, Scalar::zero());
345 schnorr_verify_batch_parse(
346 sigs,
347 msgs,
348 pubkeys,
349 n,
350 &mut BatchParseState {
351 s_scalars: &mut s_scalars,
352 e_scalars: &mut e_scalars,
353 pk_points: &mut pk_points,
354 r_points: &mut r_points,
355 randoms: &mut randoms,
356 },
357 )?;
358
359 let mut g_scalar = Scalar::zero();
360 for i in 0..n {
361 let mut term = Scalar::zero();
362 term.mul(&randoms[i], &s_scalars[i]);
363 let mut g_new = Scalar::zero();
364 g_new.add(&g_scalar, &term);
365 g_scalar = g_new;
366 }
367
368 let mut all_scalars = [Scalar::zero(); 128];
369 let mut all_points = [Ge::default(); 128];
370 for i in 0..n {
371 let mut neg_e = Scalar::zero();
372 neg_e.negate(&e_scalars[i]);
373 let mut z_neg_e = Scalar::zero();
374 z_neg_e.mul(&randoms[i], &neg_e);
375 all_scalars[i] = z_neg_e;
376 all_points[i] = pk_points[i];
377 }
378 for i in 0..n {
379 let mut neg_rand = Scalar::zero();
380 neg_rand.negate(&randoms[i]);
381 all_scalars[n + i] = neg_rand;
382 all_points[n + i] = r_points[i];
383 }
384
385 let mut result = Gej::default();
386 ecmult::ecmult_multi(
387 &mut result,
388 &g_scalar,
389 &all_scalars[..2 * n],
390 &all_points[..2 * n],
391 );
392 Some(result.is_infinity())
393}
394
395fn schnorr_verify_batch_heap(
396 sigs: &[[u8; 64]],
397 msgs: &[&[u8]],
398 pubkeys: &[[u8; 32]],
399 n: usize,
400) -> Option<bool> {
401 let mut s_scalars = Vec::with_capacity(n);
402 let mut e_scalars = Vec::with_capacity(n);
403 let mut pk_points = Vec::with_capacity(n);
404 let mut r_points = Vec::with_capacity(n);
405 let mut randoms = Vec::with_capacity(n);
406
407 for i in 0..n {
408 let r_bytes: [u8; 32] = sigs[i][0..32].try_into().unwrap();
409 let s_bytes: [u8; 32] = sigs[i][32..64].try_into().unwrap();
410
411 let mut r_fe = FieldElement::zero();
412 if !r_fe.set_b32_limit(&r_bytes) {
413 return None;
414 }
415
416 let mut s = Scalar::zero();
417 if s.set_b32(&s_bytes) {
418 return None;
419 }
420 s_scalars.push(s);
421
422 let p = lift_x(&pubkeys[i])?;
423 pk_points.push(p);
424
425 let e_hash = {
426 let mut h = challenge_midstate().clone();
427 h.update(r_bytes);
428 h.update(pubkeys[i]);
429 h.update(msgs[i]);
430 let result: [u8; 32] = h.finalize().into();
431 result
432 };
433
434 let mut e = Scalar::zero();
435 e.set_b32(&e_hash);
436 e_scalars.push(e);
437
438 let r_ge = lift_x(&r_bytes)?;
439 r_points.push(r_ge);
440
441 let mut random = Scalar::zero();
442 let _ = random.set_b32(&r_bytes);
443 let mut random_tmp = Scalar::zero();
444 random_tmp.add(&random, &e);
445 randoms.push(random_tmp);
446 }
447
448 let mut g_scalar = Scalar::zero();
449 for i in 0..n {
450 let mut term = Scalar::zero();
451 term.mul(&randoms[i], &s_scalars[i]);
452 let mut g_new = Scalar::zero();
453 g_new.add(&g_scalar, &term);
454 g_scalar = g_new;
455 }
456
457 let mut all_scalars = Vec::with_capacity(2 * n);
458 let mut all_points = Vec::with_capacity(2 * n);
459
460 for i in 0..n {
461 let mut neg_e = Scalar::zero();
462 neg_e.negate(&e_scalars[i]);
463 let mut z_neg_e = Scalar::zero();
464 z_neg_e.mul(&randoms[i], &neg_e);
465 all_scalars.push(z_neg_e);
466 all_points.push(pk_points[i]);
467 }
468 for i in 0..n {
469 let mut neg_rand = Scalar::zero();
470 neg_rand.negate(&randoms[i]);
471 all_scalars.push(neg_rand);
472 all_points.push(r_points[i]);
473 }
474
475 let mut result = Gej::default();
476 ecmult::ecmult_multi(&mut result, &g_scalar, &all_scalars, &all_points);
477 Some(result.is_infinity())
478}
479
480pub fn schnorr_sign(seckey: &[u8; 32], msg: &[u8], aux_rand32: &[u8; 32]) -> Option<[u8; 64]> {
485 let mut d = Scalar::zero();
486 if d.set_b32(seckey) {
487 return None;
488 }
489 if d.is_zero() {
490 return None;
491 }
492
493 let mut pj = Gej::default();
496 ecmult_gen_const(&mut pj, &d);
497 let mut p_ge = Ge::default();
498 p_ge.set_gej_var(&pj);
499 p_ge.x.normalize();
500 p_ge.y.normalize();
501 let pk_parity_odd = p_ge.y.is_odd();
502 let pk = ge_to_xonly(&p_ge);
503
504 schnorr_sign_inner(seckey, &d, pk_parity_odd, &pk, msg, aux_rand32)
505}
506
/// A secret key together with its cached x-only public key and y-parity.
///
/// `Debug` is implemented by hand so the secret key is never printed. The
/// previous `#[derive(Debug)]` wrote the raw secret bytes into any `{:?}`
/// output, such as logs or panic messages.
#[derive(Clone, Copy)]
pub struct Keypair {
    /// Raw 32-byte secret key.
    pub(crate) seckey: [u8; 32],
    /// x-only public key: the x-coordinate of `seckey·G`.
    pub(crate) pubkey_xonly: [u8; 32],
    /// Whether `seckey·G` has an odd y-coordinate (signing negates the key then).
    pub(crate) pk_parity_odd: bool,
}

impl std::fmt::Debug for Keypair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Keypair")
            .field("seckey", &format_args!("<redacted>"))
            .field("pubkey_xonly", &self.pubkey_xonly)
            .field("pk_parity_odd", &self.pk_parity_odd)
            .finish()
    }
}
522
523impl Keypair {
524 pub fn from_seckey(seckey32: &[u8; 32]) -> Option<Self> {
528 let mut d = Scalar::zero();
529 if d.set_b32(seckey32) {
530 return None;
531 }
532 if d.is_zero() {
533 return None;
534 }
535 let mut pj = Gej::default();
536 ecmult_gen_const(&mut pj, &d);
537 let mut p = Ge::default();
538 p.set_gej_var(&pj);
539 p.x.normalize();
540 p.y.normalize();
541 let pk_parity_odd = p.y.is_odd();
542 let pubkey_xonly = ge_to_xonly(&p);
543 Some(Keypair {
544 seckey: *seckey32,
545 pubkey_xonly,
546 pk_parity_odd,
547 })
548 }
549
550 #[inline]
552 pub fn xonly_pubkey(&self) -> [u8; 32] {
553 self.pubkey_xonly
554 }
555}
556
557pub fn schnorr_sign_with_keypair(
562 keypair: &Keypair,
563 msg: &[u8],
564 aux_rand32: &[u8; 32],
565) -> Option<[u8; 64]> {
566 let mut d = Scalar::zero();
567 if d.set_b32(&keypair.seckey) {
568 return None;
569 }
570 if d.is_zero() {
571 return None;
572 }
573 schnorr_sign_inner(
574 &keypair.seckey,
575 &d,
576 keypair.pk_parity_odd,
577 &keypair.pubkey_xonly,
578 msg,
579 aux_rand32,
580 )
581}
582
583#[inline]
584fn schnorr_sign_inner(
585 seckey: &[u8; 32],
586 d: &Scalar,
587 pk_parity_odd: bool,
588 pk: &[u8; 32],
589 msg: &[u8],
590 aux_rand32: &[u8; 32],
591) -> Option<[u8; 64]> {
592 let mut d_adj = *d;
593 if pk_parity_odd {
594 d_adj.negate(d);
595 }
596
597 let aux_hash = tagged_hash_from_midstate(aux_midstate(), aux_rand32);
598 let mut masked_key = [0u8; 32];
599 for i in 0..32 {
600 masked_key[i] = seckey[i] ^ aux_hash[i];
601 }
602
603 let k_hash = {
604 let mut h = nonce_midstate().clone();
605 h.update(masked_key);
606 h.update(*pk);
607 h.update(msg);
608 let result: [u8; 32] = h.finalize().into();
609 result
610 };
611
612 let mut k = Scalar::zero();
613 if k.set_b32(&k_hash) {
614 return None;
615 }
616 if k.is_zero() {
617 return None;
618 }
619
620 let mut rj = Gej::default();
622 ecmult_gen_const(&mut rj, &k);
623 let mut r_ge = Ge::default();
624 r_ge.set_gej_var(&rj);
625 r_ge.x.normalize();
626 r_ge.y.normalize();
627
628 let parity = r_ge.y.is_odd() as i32;
631 k.cond_negate(parity);
632 let mut neg_ry = FieldElement::zero();
633 neg_ry.negate(&r_ge.y, 1);
634 neg_ry.normalize();
635 r_ge.y.normalize();
636 r_ge.y.cmov(&neg_ry, Choice::from(parity as u8));
637
638 let mut r_bytes = [0u8; 32];
639 r_ge.x.get_b32(&mut r_bytes);
640
641 let e_hash = {
642 let mut h = challenge_midstate().clone();
643 h.update(r_bytes);
644 h.update(*pk);
645 h.update(msg);
646 let result: [u8; 32] = h.finalize().into();
647 result
648 };
649
650 let mut e = Scalar::zero();
651 e.set_b32(&e_hash);
652
653 let mut ed = Scalar::zero();
655 ed.mul(&e, &d_adj);
656 let mut s = Scalar::zero();
657 s.add(&k, &ed);
658
659 let mut s_bytes = [0u8; 32];
660 s.get_b32(&mut s_bytes);
661
662 let mut sig = [0u8; 64];
663 sig[0..32].copy_from_slice(&r_bytes);
664 sig[32..64].copy_from_slice(&s_bytes);
665
666 Some(sig)
667}