1use blst::{
14 blst_bendian_from_scalar, blst_fp12, blst_fr, blst_fr_add, blst_fr_from_scalar,
15 blst_fr_from_uint64, blst_fr_inverse, blst_fr_mul, blst_fr_sub, blst_hash_to_g1,
16 blst_hash_to_g2, blst_keygen, blst_p1, blst_p1_add_or_double, blst_p1_affine, blst_p1_compress,
17 blst_p1_from_affine, blst_p1_in_g1, blst_p1_is_inf, blst_p1_mult, blst_p1_to_affine,
18 blst_p1_uncompress, blst_p2, blst_p2_add_or_double, blst_p2_affine, blst_p2_compress,
19 blst_p2_from_affine, blst_p2_in_g2, blst_p2_is_inf, blst_p2_mult, blst_p2_to_affine,
20 blst_p2_uncompress, blst_scalar, blst_scalar_from_bendian, blst_scalar_from_fr, blst_sk_check,
21 Pairing, BLS12_381_G1, BLS12_381_G2, BLS12_381_NEG_G1, BLST_ERROR,
22};
23use bytes::{Buf, BufMut};
24use commonware_codec::{
25 varint::UInt,
26 EncodeSize,
27 Error::{self, Invalid},
28 FixedSize, Read, ReadExt, Write,
29};
30use commonware_utils::hex;
31use rand::RngCore;
32use std::{
33 fmt::{Debug, Display},
34 hash::{Hash, Hasher},
35 ptr,
36};
37use zeroize::{Zeroize, ZeroizeOnDrop};
38
/// Domain separation tag passed to blst's hash-to-curve routines.
pub type DST = &'static [u8];

/// Common operations shared by scalars and curve points in this module.
///
/// Implementors are fixed-size, codec-able values (`Read`/`Write`/`FixedSize`)
/// that can be cloned, compared, and shared across threads.
pub trait Element:
    Read<Cfg = ()> + Write + FixedSize + Clone + Eq + PartialEq + Send + Sync
{
    /// Returns the additive identity (zero scalar / point at infinity).
    fn zero() -> Self;

    /// Returns the multiplicative identity for scalars, or the group generator for points.
    fn one() -> Self;

    /// Adds `rhs` to `self` in place.
    fn add(&mut self, rhs: &Self);

    /// Multiplies `self` by the scalar `rhs` in place.
    fn mul(&mut self, rhs: &Scalar);
}

/// An [`Element`] that is a curve point and can also be derived by hashing a message.
pub trait Point: Element {
    /// Overwrites `self` with `message` hashed onto the curve under tag `dst`.
    fn map(&mut self, dst: DST, message: &[u8]);
}
66
/// An element of the BLS12-381 scalar field, wrapping blst's `blst_fr`.
///
/// The limbs are wiped when the value is dropped (see the `Drop` impl).
#[derive(Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct Scalar(blst_fr);

/// Length in bytes of a serialized scalar (big-endian, see `Scalar::as_slice`).
const SCALAR_LENGTH: usize = 32;

/// Number of scalar bits handed to blst's point-multiplication routines.
///
/// The scalar field modulus r is below 2^255, so 255 bits cover every canonical scalar.
const SCALAR_BITS: usize = 255;

/// The scalar `1` in blst's internal limb representation.
///
/// The limbs (least-significant first) encode 2^256 mod r, i.e. `1` in
/// Montgomery form, which is how blst appears to store `blst_fr` values.
/// They are deliberately NOT the plain integer 1.
const BLST_FR_ONE: Scalar = Scalar(blst_fr {
    l: [
        0x0000_0001_ffff_fffe,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ],
});
114
/// A point in the BLS12-381 G1 group, wrapping blst's non-affine `blst_p1`.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G1(blst_p1);

/// Length in bytes of a compressed G1 point.
pub const G1_ELEMENT_BYTE_LENGTH: usize = 48;

/// Domain separation tag for hashing proof-of-possession messages onto G1.
pub const G1_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";

/// Domain separation tag for hashing signed messages onto G1.
pub const G1_MESSAGE: DST = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_POP_";
133
/// A point in the BLS12-381 G2 group, wrapping blst's non-affine `blst_p2`.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
pub struct G2(blst_p2);

/// Length in bytes of a compressed G2 point.
pub const G2_ELEMENT_BYTE_LENGTH: usize = 96;

/// Domain separation tag for hashing proof-of-possession messages onto G2.
pub const G2_PROOF_OF_POSSESSION: DST = b"BLS_POP_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

/// Domain separation tag for hashing signed messages onto G2.
pub const G2_MESSAGE: DST = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";
152
/// An element of the pairing target group GT, wrapping blst's `blst_fp12`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct GT(blst_fp12);

/// Private keys are scalars.
pub type Private = Scalar;

/// Length in bytes of a serialized private key.
pub const PRIVATE_KEY_LENGTH: usize = SCALAR_LENGTH;

/// Public keys are G1 points (the small-public-key / large-signature variant).
pub type Public = G1;

/// Length in bytes of a serialized (compressed) public key.
pub const PUBLIC_KEY_LENGTH: usize = G1_ELEMENT_BYTE_LENGTH;

/// Signatures are G2 points.
pub type Signature = G2;

/// Length in bytes of a serialized (compressed) signature.
pub const SIGNATURE_LENGTH: usize = G2_ELEMENT_BYTE_LENGTH;

/// DST for proofs of possession; signatures live in G2, so the G2 tag is used.
pub const PROOF_OF_POSSESSION: DST = G2_PROOF_OF_POSSESSION;

/// DST for hashing messages to be signed; signatures live in G2, so the G2 tag is used.
pub const MESSAGE: DST = G2_MESSAGE;
183
184impl Scalar {
185 pub fn rand<R: RngCore>(rng: &mut R) -> Self {
187 let mut ikm = [0u8; 64];
189 rng.fill_bytes(&mut ikm);
190
191 let mut ret = blst_fr::default();
193 unsafe {
194 let mut sc = blst_scalar::default();
195 blst_keygen(&mut sc, ikm.as_ptr(), ikm.len(), ptr::null(), 0);
196 blst_fr_from_scalar(&mut ret, &sc);
197 }
198
199 ikm.zeroize();
201 Self(ret)
202 }
203
204 pub fn set_int(&mut self, i: u32) {
206 let buffer = [i as u64, 0, 0, 0];
211 unsafe { blst_fr_from_uint64(&mut self.0, buffer.as_ptr()) };
212 }
213
214 pub fn inverse(&self) -> Option<Self> {
216 if *self == Self::zero() {
217 return None;
218 }
219 let mut ret = blst_fr::default();
220 unsafe { blst_fr_inverse(&mut ret, &self.0) };
221 Some(Self(ret))
222 }
223
224 pub fn sub(&mut self, rhs: &Self) {
226 unsafe { blst_fr_sub(&mut self.0, &self.0, &rhs.0) }
227 }
228
229 fn as_slice(&self) -> [u8; Self::SIZE] {
231 let mut slice = [0u8; Self::SIZE];
232 unsafe {
233 let mut scalar = blst_scalar::default();
234 blst_scalar_from_fr(&mut scalar, &self.0);
235 blst_bendian_from_scalar(slice.as_mut_ptr(), &scalar);
236 }
237 slice
238 }
239}
240
241impl Element for Scalar {
242 fn zero() -> Self {
243 Self(blst_fr::default())
244 }
245
246 fn one() -> Self {
247 BLST_FR_ONE
248 }
249
250 fn add(&mut self, rhs: &Self) {
251 unsafe {
252 blst_fr_add(&mut self.0, &self.0, &rhs.0);
253 }
254 }
255
256 fn mul(&mut self, rhs: &Self) {
257 unsafe {
258 blst_fr_mul(&mut self.0, &self.0, &rhs.0);
259 }
260 }
261}
262
263impl Write for Scalar {
264 fn write(&self, buf: &mut impl BufMut) {
265 let slice = self.as_slice();
266 buf.put_slice(&slice);
267 }
268}
269
270impl Read for Scalar {
271 type Cfg = ();
272
273 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
274 let bytes = <[u8; Self::SIZE]>::read(buf)?;
275 let mut ret = blst_fr::default();
276 unsafe {
277 let mut scalar = blst_scalar::default();
278 blst_scalar_from_bendian(&mut scalar, bytes.as_ptr());
279 if !blst_sk_check(&scalar) {
289 return Err(Invalid("Scalar", "Invalid"));
290 }
291 blst_fr_from_scalar(&mut ret, &scalar);
292 }
293 Ok(Self(ret))
294 }
295}
296
297impl FixedSize for Scalar {
298 const SIZE: usize = SCALAR_LENGTH;
299}
300
301impl Hash for Scalar {
302 fn hash<H: Hasher>(&self, state: &mut H) {
303 let slice = self.as_slice();
304 state.write(&slice);
305 }
306}
307
308impl Debug for Scalar {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 write!(f, "{}", hex(&self.as_slice()))
311 }
312}
313
314impl Display for Scalar {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 write!(f, "{}", hex(&self.as_slice()))
317 }
318}
319
320impl Zeroize for Scalar {
321 fn zeroize(&mut self) {
322 self.0.l.zeroize();
323 }
324}
325
326impl Drop for Scalar {
327 fn drop(&mut self) {
328 self.zeroize();
329 }
330}
331
332impl ZeroizeOnDrop for Scalar {}
333
/// A private key share identified by an index.
#[derive(Clone, PartialEq, Hash)]
pub struct Share {
    /// Index identifying this share (encoded as a varint on the wire).
    pub index: u32,
    /// The private scalar held by this share.
    pub private: Private,
}
342
343impl Share {
344 pub fn public(&self) -> Public {
348 let mut public = <Public as Element>::one();
349 public.mul(&self.private);
350 public
351 }
352}
353
354impl Write for Share {
355 fn write(&self, buf: &mut impl BufMut) {
356 UInt(self.index).write(buf);
357 self.private.write(buf);
358 }
359}
360
361impl Read for Share {
362 type Cfg = ();
363
364 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
365 let index = UInt::read(buf)?.into();
366 let private = Private::read(buf)?;
367 Ok(Self { index, private })
368 }
369}
370
371impl EncodeSize for Share {
372 fn encode_size(&self) -> usize {
373 UInt(self.index).encode_size() + self.private.encode_size()
374 }
375}
376
377impl Display for Share {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 write!(f, "Share(index={}, private={})", self.index, self.private)
380 }
381}
382
383impl Debug for Share {
384 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 write!(f, "Share(index={}, private={})", self.index, self.private)
386 }
387}
388
389impl G1 {
390 fn as_slice(&self) -> [u8; Self::SIZE] {
392 let mut slice = [0u8; Self::SIZE];
393 unsafe {
394 blst_p1_compress(slice.as_mut_ptr(), &self.0);
395 }
396 slice
397 }
398}
399
400impl Element for G1 {
401 fn zero() -> Self {
402 Self(blst_p1::default())
403 }
404
405 fn one() -> Self {
406 let mut ret = blst_p1::default();
407 unsafe {
408 blst_p1_from_affine(&mut ret, &BLS12_381_G1);
409 }
410 Self(ret)
411 }
412
413 fn add(&mut self, rhs: &Self) {
414 unsafe {
415 blst_p1_add_or_double(&mut self.0, &self.0, &rhs.0);
416 }
417 }
418
419 fn mul(&mut self, rhs: &Scalar) {
420 let mut scalar: blst_scalar = blst_scalar::default();
421 unsafe {
422 blst_scalar_from_fr(&mut scalar, &rhs.0);
423 blst_p1_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
426 }
427 }
428}
429
430impl Write for G1 {
431 fn write(&self, buf: &mut impl BufMut) {
432 let slice = self.as_slice();
433 buf.put_slice(&slice);
434 }
435}
436
437impl Read for G1 {
438 type Cfg = ();
439
440 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
441 let bytes = <[u8; Self::SIZE]>::read(buf)?;
442 let mut ret = blst_p1::default();
443 unsafe {
444 let mut affine = blst_p1_affine::default();
445 match blst_p1_uncompress(&mut affine, bytes.as_ptr()) {
446 BLST_ERROR::BLST_SUCCESS => {}
447 BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G1", "Bad encoding")),
448 BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G1", "Not on curve")),
449 BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G1", "Not in group")),
450 BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G1", "Type mismatch")),
451 BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G1", "Verify fail")),
452 BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G1", "PK is Infinity")),
453 BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G1", "Bad scalar")),
454 }
455 blst_p1_from_affine(&mut ret, &affine);
456
457 if blst_p1_is_inf(&ret) {
459 return Err(Invalid("G1", "Infinity"));
460 }
461
462 if !blst_p1_in_g1(&ret) {
464 return Err(Invalid("G1", "Outside G1"));
465 }
466 }
467 Ok(Self(ret))
468 }
469}
470
471impl FixedSize for G1 {
472 const SIZE: usize = G1_ELEMENT_BYTE_LENGTH;
473}
474
475impl Hash for G1 {
476 fn hash<H: Hasher>(&self, state: &mut H) {
477 let slice = self.as_slice();
478 state.write(&slice);
479 }
480}
481
482impl Point for G1 {
483 fn map(&mut self, dst: DST, data: &[u8]) {
484 unsafe {
485 blst_hash_to_g1(
486 &mut self.0,
487 data.as_ptr(),
488 data.len(),
489 dst.as_ptr(),
490 dst.len(),
491 ptr::null(),
492 0,
493 );
494 }
495 }
496}
497
498impl Debug for G1 {
499 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500 write!(f, "{}", hex(&self.as_slice()))
501 }
502}
503
504impl Display for G1 {
505 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
506 write!(f, "{}", hex(&self.as_slice()))
507 }
508}
509
510impl G2 {
511 fn as_slice(&self) -> [u8; Self::SIZE] {
513 let mut slice = [0u8; Self::SIZE];
514 unsafe {
515 blst_p2_compress(slice.as_mut_ptr(), &self.0);
516 }
517 slice
518 }
519}
520
521impl Element for G2 {
522 fn zero() -> Self {
523 Self(blst_p2::default())
524 }
525
526 fn one() -> Self {
527 let mut ret = blst_p2::default();
528 unsafe {
529 blst_p2_from_affine(&mut ret, &BLS12_381_G2);
530 }
531 Self(ret)
532 }
533
534 fn add(&mut self, rhs: &Self) {
535 unsafe {
536 blst_p2_add_or_double(&mut self.0, &self.0, &rhs.0);
537 }
538 }
539
540 fn mul(&mut self, rhs: &Scalar) {
541 let mut scalar = blst_scalar::default();
542 unsafe {
543 blst_scalar_from_fr(&mut scalar, &rhs.0);
544 blst_p2_mult(&mut self.0, &self.0, scalar.b.as_ptr(), SCALAR_BITS);
547 }
548 }
549}
550
551impl Write for G2 {
552 fn write(&self, buf: &mut impl BufMut) {
553 let slice = self.as_slice();
554 buf.put_slice(&slice);
555 }
556}
557
558impl Read for G2 {
559 type Cfg = ();
560
561 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
562 let bytes = <[u8; Self::SIZE]>::read(buf)?;
563 let mut ret = blst_p2::default();
564 unsafe {
565 let mut affine = blst_p2_affine::default();
566 match blst_p2_uncompress(&mut affine, bytes.as_ptr()) {
567 BLST_ERROR::BLST_SUCCESS => {}
568 BLST_ERROR::BLST_BAD_ENCODING => return Err(Invalid("G2", "Bad encoding")),
569 BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(Invalid("G2", "Not on curve")),
570 BLST_ERROR::BLST_POINT_NOT_IN_GROUP => return Err(Invalid("G2", "Not in group")),
571 BLST_ERROR::BLST_AGGR_TYPE_MISMATCH => return Err(Invalid("G2", "Type mismatch")),
572 BLST_ERROR::BLST_VERIFY_FAIL => return Err(Invalid("G2", "Verify fail")),
573 BLST_ERROR::BLST_PK_IS_INFINITY => return Err(Invalid("G2", "PK is Infinity")),
574 BLST_ERROR::BLST_BAD_SCALAR => return Err(Invalid("G2", "Bad scalar")),
575 }
576 blst_p2_from_affine(&mut ret, &affine);
577
578 if blst_p2_is_inf(&ret) {
580 return Err(Invalid("G2", "Infinity"));
581 }
582
583 if !blst_p2_in_g2(&ret) {
585 return Err(Invalid("G2", "Outside G2"));
586 }
587 }
588 Ok(Self(ret))
589 }
590}
591
592impl FixedSize for G2 {
593 const SIZE: usize = G2_ELEMENT_BYTE_LENGTH;
594}
595
596impl Hash for G2 {
597 fn hash<H: Hasher>(&self, state: &mut H) {
598 let slice = self.as_slice();
599 state.write(&slice);
600 }
601}
602
603impl Point for G2 {
604 fn map(&mut self, dst: DST, data: &[u8]) {
605 unsafe {
606 blst_hash_to_g2(
607 &mut self.0,
608 data.as_ptr(),
609 data.len(),
610 dst.as_ptr(),
611 dst.len(),
612 ptr::null(),
613 0,
614 );
615 }
616 }
617}
618
619impl Debug for G2 {
620 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621 write!(f, "{}", hex(&self.as_slice()))
622 }
623}
624
625impl Display for G2 {
626 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
627 write!(f, "{}", hex(&self.as_slice()))
628 }
629}
630
631pub(super) fn equal(pk: &G1, sig: &G2, hm: &G2) -> bool {
634 let mut pairing = Pairing::new(false, &[]);
638
639 let mut q = blst_p2_affine::default();
641 unsafe {
642 blst_p2_to_affine(&mut q, &sig.0);
643 pairing.raw_aggregate(&q, &BLS12_381_NEG_G1);
644 }
645
646 let mut p = blst_p1_affine::default();
648 let mut q = blst_p2_affine::default();
649 unsafe {
650 blst_p1_to_affine(&mut p, &pk.0);
651 blst_p2_to_affine(&mut q, &hm.0);
652 }
653
654 pairing.raw_aggregate(&q, &p);
656
657 pairing.commit();
662 pairing.finalverify(None)
663}
664
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{DecodeExt, Encode};
    use rand::prelude::*;

    #[test]
    fn basic_group() {
        let s = Scalar::rand(&mut thread_rng());

        // (s + s) * s computed on two independent copies must agree.
        let mut e1 = s.clone();
        let e2 = s.clone();
        let mut s2 = s.clone();
        s2.add(&s);
        s2.mul(&s);
        e1.add(&e2);
        e1.mul(&e2);
        assert_eq!(e1, s2);

        // Scale the generator. Scaling the identity would make the check vacuous.
        // [2s^2]G must equal [s^2]G + [s^2]G.
        let mut p1 = G1::one();
        p1.mul(&s2);

        let mut p2 = G1::one();
        p2.mul(&s);
        p2.mul(&s);
        p2.add(&p2.clone());
        assert_eq!(p1, p2);
        assert_ne!(p1, G1::zero());
    }

    #[test]
    fn test_scalar_inverse() {
        assert!(Scalar::zero().inverse().is_none());

        let s = Scalar::rand(&mut thread_rng());
        let mut product = s.inverse().expect("random scalar should be non-zero");
        product.mul(&s);
        assert_eq!(product, Scalar::one());
    }

    #[test]
    fn test_scalar_codec() {
        let original = Scalar::rand(&mut thread_rng());
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), Scalar::SIZE);
        let decoded = Scalar::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_scalar_zero_rejected() {
        // `Scalar::read` validates with `blst_sk_check`, which refuses zero.
        let mut encoded = Scalar::zero().encode();
        assert!(Scalar::decode(&mut encoded).is_err());
    }

    #[test]
    fn test_g1_codec() {
        let mut original = G1::one();
        original.mul(&Scalar::rand(&mut thread_rng()));
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G1::SIZE);
        let decoded = G1::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_g2_codec() {
        let mut original = G2::one();
        original.mul(&Scalar::rand(&mut thread_rng()));
        let mut encoded = original.encode();
        assert_eq!(encoded.len(), G2::SIZE);
        let decoded = G2::decode(&mut encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_identity_points_rejected() {
        let mut encoded = G1::zero().encode();
        assert!(G1::decode(&mut encoded).is_err());
        let mut encoded = G2::zero().encode();
        assert!(G2::decode(&mut encoded).is_err());
    }
}