chia_bls/
public_key.rs

1use crate::secret_key::is_all_zero;
2use crate::{DerivableKey, Error, Result};
3
4use blst::*;
5use chia_sha2::Sha256;
6use chia_traits::{Streamable, read_bytes};
7#[cfg(feature = "py-bindings")]
8use pyo3::exceptions::PyNotImplementedError;
9#[cfg(feature = "py-bindings")]
10use pyo3::prelude::*;
11#[cfg(feature = "py-bindings")]
12use pyo3::types::PyType;
13use std::fmt;
14use std::hash::{Hash, Hasher};
15use std::io::Cursor;
16use std::mem::MaybeUninit;
17use std::ops::{Add, AddAssign, Neg, SubAssign};
18
19#[cfg_attr(
20    feature = "py-bindings",
21    pyo3::pyclass(name = "G1Element"),
22    derive(chia_py_streamable_macro::PyStreamable)
23)]
24#[derive(Clone, Copy, Default)]
25pub struct PublicKey(pub(crate) blst_p1);
26
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for PublicKey {
    /// Placeholder implementation.
    ///
    /// Ignores the fuzzer input (no bytes are consumed) and always yields
    /// `PublicKey::default()`.
    fn arbitrary(_unstructured: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        Ok(PublicKey::default())
    }
}
34
35impl PublicKey {
36    pub fn from_bytes_unchecked(bytes: &[u8; 48]) -> Result<Self> {
37        // check if the element is canonical
38        // the first 3 bits have special meaning
39        let zeros_only = is_all_zero(&bytes[1..]);
40
41        if (bytes[0] & 0xc0) == 0xc0 {
42            // enforce that infinity must be 0xc0000..00
43            if bytes[0] != 0xc0 || !zeros_only {
44                return Err(Error::G1NotCanonical);
45            }
46            // return infinity element (point all zero)
47            return Ok(Self::default());
48        }
49
50        if (bytes[0] & 0xc0) != 0x80 {
51            return Err(Error::G1InfinityInvalidBits);
52        }
53        if zeros_only {
54            return Err(Error::G1InfinityNotZero);
55        }
56
57        let p1 = unsafe {
58            let mut p1_affine = MaybeUninit::<blst_p1_affine>::uninit();
59            let ret = blst_p1_uncompress(p1_affine.as_mut_ptr(), bytes.as_ptr());
60            if ret != BLST_ERROR::BLST_SUCCESS {
61                return Err(Error::InvalidPublicKey(ret));
62            }
63            let mut p1 = MaybeUninit::<blst_p1>::uninit();
64            blst_p1_from_affine(p1.as_mut_ptr(), &p1_affine.assume_init());
65            p1.assume_init()
66        };
67        Ok(Self(p1))
68    }
69
70    pub fn generator() -> Self {
71        let p1 = unsafe { *blst_p1_generator() };
72        Self(p1)
73    }
74
75    // Creates a G1 point by multiplying the generator by the specified scalar.
76    // This is the same as creating a private key from the scalar, and then get
77    // the corresponding public key
78    pub fn from_integer(int_bytes: &[u8]) -> Self {
79        let p1 = unsafe {
80            let mut scalar = MaybeUninit::<blst_scalar>::uninit();
81            blst_scalar_from_be_bytes(scalar.as_mut_ptr(), int_bytes.as_ptr(), int_bytes.len());
82            let mut p1 = MaybeUninit::<blst_p1>::uninit();
83            blst_p1_mult(
84                p1.as_mut_ptr(),
85                blst_p1_generator(),
86                scalar.as_ptr().cast::<u8>(),
87                256,
88            );
89            p1.assume_init()
90        };
91        Self(p1)
92    }
93
94    pub fn from_bytes(bytes: &[u8; 48]) -> Result<Self> {
95        let ret = Self::from_bytes_unchecked(bytes)?;
96        if ret.is_valid() {
97            Ok(ret)
98        } else {
99            Err(Error::InvalidPublicKey(BLST_ERROR::BLST_POINT_NOT_ON_CURVE))
100        }
101    }
102
103    pub fn from_uncompressed(buf: &[u8; 96]) -> Result<Self> {
104        let p1 = unsafe {
105            let mut p1_affine = MaybeUninit::<blst_p1_affine>::uninit();
106            let ret = blst_p1_deserialize(p1_affine.as_mut_ptr(), buf.as_ptr());
107            if ret != BLST_ERROR::BLST_SUCCESS {
108                return Err(Error::InvalidSignature(ret));
109            }
110            let mut p1 = MaybeUninit::<blst_p1>::uninit();
111            blst_p1_from_affine(p1.as_mut_ptr(), &p1_affine.assume_init());
112            p1.assume_init()
113        };
114        Ok(Self(p1))
115    }
116
117    pub fn to_bytes(&self) -> [u8; 48] {
118        unsafe {
119            let mut bytes = MaybeUninit::<[u8; 48]>::uninit();
120            blst_p1_compress(bytes.as_mut_ptr().cast::<u8>(), &raw const self.0);
121            bytes.assume_init()
122        }
123    }
124
125    pub fn is_valid(&self) -> bool {
126        // Infinity was considered a valid G1Element in older Relic versions
127        // For historical compatibililty this behavior is maintained.
128        unsafe { blst_p1_is_inf(&raw const self.0) || blst_p1_in_g1(&raw const self.0) }
129    }
130
131    pub fn is_inf(&self) -> bool {
132        unsafe { blst_p1_is_inf(&raw const self.0) }
133    }
134
135    pub fn negate(&mut self) {
136        unsafe {
137            blst_p1_cneg(&raw mut self.0, true);
138        }
139    }
140
141    pub fn scalar_multiply(&mut self, int_bytes: &[u8]) {
142        unsafe {
143            let mut scalar = MaybeUninit::<blst_scalar>::uninit();
144            blst_scalar_from_be_bytes(scalar.as_mut_ptr(), int_bytes.as_ptr(), int_bytes.len());
145            blst_p1_mult(
146                &raw mut self.0,
147                &raw const self.0,
148                scalar.as_ptr().cast::<u8>(),
149                256,
150            );
151        }
152    }
153
154    pub fn get_fingerprint(&self) -> u32 {
155        let mut hasher = Sha256::new();
156        hasher.update(self.to_bytes());
157        let hash: [u8; 32] = hasher.finalize();
158        u32::from_be_bytes(hash[0..4].try_into().unwrap())
159    }
160}
161
162impl PartialEq for PublicKey {
163    fn eq(&self, other: &Self) -> bool {
164        unsafe { blst_p1_is_equal(&raw const self.0, &raw const other.0) }
165    }
166}
167impl Eq for PublicKey {}
168
#[cfg(feature = "serde")]
impl serde::Serialize for PublicKey {
    /// Serializes the 48-byte compressed encoding through `chia_serde::ser_bytes`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let compressed = self.to_bytes();
        chia_serde::ser_bytes(&compressed, serializer, true)
    }
}
178
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PublicKey {
    /// Deserializes bytes with `chia_serde::de_bytes` and validates them with
    /// [`PublicKey::from_bytes`], mapping decode errors to a custom serde error.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let decoded = Self::from_bytes(&chia_serde::de_bytes(deserializer)?);
        decoded.map_err(serde::de::Error::custom)
    }
}
188
189impl Streamable for PublicKey {
190    fn update_digest(&self, digest: &mut Sha256) {
191        digest.update(self.to_bytes());
192    }
193
194    fn stream(&self, out: &mut Vec<u8>) -> chia_traits::Result<()> {
195        out.extend_from_slice(&self.to_bytes());
196        Ok(())
197    }
198
199    fn parse<const TRUSTED: bool>(input: &mut Cursor<&[u8]>) -> chia_traits::Result<Self> {
200        let input = read_bytes(input, 48)?.try_into().unwrap();
201        if TRUSTED {
202            Ok(Self::from_bytes_unchecked(input)?)
203        } else {
204            Ok(Self::from_bytes(input)?)
205        }
206    }
207}
208
209impl Hash for PublicKey {
210    fn hash<H: Hasher>(&self, state: &mut H) {
211        state.write(&self.to_bytes());
212    }
213}
214
215impl Neg for PublicKey {
216    type Output = PublicKey;
217    fn neg(mut self) -> Self::Output {
218        self.negate();
219        self
220    }
221}
222
223impl Neg for &PublicKey {
224    type Output = PublicKey;
225    fn neg(self) -> Self::Output {
226        let mut ret = *self;
227        ret.negate();
228        ret
229    }
230}
231
232impl AddAssign<&PublicKey> for PublicKey {
233    fn add_assign(&mut self, rhs: &PublicKey) {
234        unsafe {
235            blst_p1_add_or_double(&raw mut self.0, &raw const self.0, &raw const rhs.0);
236        }
237    }
238}
239
240impl SubAssign<&PublicKey> for PublicKey {
241    fn sub_assign(&mut self, rhs: &PublicKey) {
242        unsafe {
243            let mut neg = *rhs;
244            blst_p1_cneg(&raw mut neg.0, true);
245            blst_p1_add_or_double(&raw mut self.0, &raw const self.0, &raw const neg.0);
246        }
247    }
248}
249
250impl Add<&PublicKey> for &PublicKey {
251    type Output = PublicKey;
252    fn add(self, rhs: &PublicKey) -> PublicKey {
253        let p1 = unsafe {
254            let mut ret = MaybeUninit::<blst_p1>::uninit();
255            blst_p1_add_or_double(ret.as_mut_ptr(), &raw const self.0, &raw const rhs.0);
256            ret.assume_init()
257        };
258        PublicKey(p1)
259    }
260}
261
262impl Add<&PublicKey> for PublicKey {
263    type Output = PublicKey;
264    fn add(mut self, rhs: &PublicKey) -> PublicKey {
265        unsafe {
266            blst_p1_add_or_double(&raw mut self.0, &raw const self.0, &raw const rhs.0);
267            self
268        }
269    }
270}
271
272impl fmt::Debug for PublicKey {
273    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
274        formatter.write_fmt(format_args!(
275            "<G1Element {}>",
276            &hex::encode(self.to_bytes())
277        ))
278    }
279}
280
impl DerivableKey for PublicKey {
    /// Derives the unhardened child public key at index `idx`.
    ///
    /// Computes `digest = SHA-256(self.to_bytes() || idx.to_be_bytes())` and
    /// turns it into a scalar `s`. It returns `s * G1_generator + self`.
    /// `test_derive_unhardened` checks that this matches
    /// `SecretKey::derive_unhardened(idx).public_key()`.
    fn derive_unhardened(&self, idx: u32) -> Self {
        let mut hasher = Sha256::new();
        hasher.update(self.to_bytes());
        hasher.update(idx.to_be_bytes());
        let digest: [u8; 32] = hasher.finalize();

        let p1 = unsafe {
            // The digest is loaded as a little-endian scalar and then written back
            // out big-endian. Those bytes are what blst_p1_mult consumes below.
            // This exact byte handling must stay in sync with the secret-key
            // derivation (NOTE(review): confirm against secret_key.rs before changing).
            let mut nonce = MaybeUninit::<blst_scalar>::uninit();
            blst_scalar_from_lendian(nonce.as_mut_ptr(), digest.as_ptr());
            // 48-byte scratch buffer. blst_p1_mult below reads only 256 bits
            // (32 bytes) of it; presumably those are exactly the bytes
            // blst_bendian_from_scalar writes — TODO confirm.
            let mut bte = MaybeUninit::<[u8; 48]>::uninit();
            blst_bendian_from_scalar(bte.as_mut_ptr().cast::<u8>(), nonce.as_ptr());
            let mut p1 = MaybeUninit::<blst_p1>::uninit();
            blst_p1_mult(
                p1.as_mut_ptr(),
                blst_p1_generator(),
                bte.as_ptr().cast::<u8>(),
                256,
            );
            // p1 = s*G + self. NOTE(review): this uses blst_p1_add rather than the
            // blst_p1_add_or_double used elsewhere in this file; verify blst's
            // handling of the (negligibly likely) s*G == self case.
            blst_p1_add(p1.as_mut_ptr(), p1.as_mut_ptr(), &raw const self.0);
            p1.assume_init()
        };
        PublicKey(p1)
    }
}
306
307pub(crate) const DST: &[u8] = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_AUG_";
308
309pub fn hash_to_g1(msg: &[u8]) -> PublicKey {
310    hash_to_g1_with_dst(msg, DST)
311}
312
313pub fn hash_to_g1_with_dst(msg: &[u8], dst: &[u8]) -> PublicKey {
314    let p1 = unsafe {
315        let mut p1 = MaybeUninit::<blst_p1>::uninit();
316        blst_hash_to_g1(
317            p1.as_mut_ptr(),
318            msg.as_ptr(),
319            msg.len(),
320            dst.as_ptr(),
321            dst.len(),
322            std::ptr::null(),
323            0,
324        );
325        p1.assume_init()
326    };
327    PublicKey(p1)
328}
329
#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl PublicKey {
    /// Size in bytes of the compressed encoding.
    #[classattr]
    pub const SIZE: usize = 48;

    /// Python constructor; yields the default (all-zero, infinity) element.
    #[new]
    pub fn init() -> Self {
        PublicKey::default()
    }

    /// Python `generator()`; returns the G1 generator point.
    #[staticmethod]
    #[pyo3(name = "generator")]
    pub fn py_generator() -> Self {
        PublicKey::generator()
    }

    /// Checks `signature` over `msg` for this key by delegating to `crate::verify`.
    pub fn verify(&self, signature: &crate::Signature, msg: &[u8]) -> bool {
        crate::verify(signature, self, msg)
    }

    /// Pairs this point with `other` by delegating to `Signature::pair`.
    pub fn pair(&self, other: &crate::Signature) -> crate::GTElement {
        other.pair(self)
    }

    /// Always raises `NotImplementedError`.
    #[classmethod]
    #[pyo3(name = "from_parent")]
    pub fn from_parent(_cls: &Bound<'_, PyType>, _instance: &Self) -> PyResult<Py<PyAny>> {
        let message = "PublicKey does not support from_parent().";
        Err(PyNotImplementedError::new_err(message))
    }

    /// Python `get_fingerprint()`; see [`PublicKey::get_fingerprint`].
    #[pyo3(name = "get_fingerprint")]
    pub fn py_get_fingerprint(&self) -> u32 {
        PublicKey::get_fingerprint(self)
    }

    /// Python `derive_unhardened(idx)`; see `DerivableKey::derive_unhardened`.
    #[pyo3(name = "derive_unhardened")]
    #[must_use]
    pub fn py_derive_unhardened(&self, idx: u32) -> Self {
        DerivableKey::derive_unhardened(self, idx)
    }

    /// Hex encoding (no `0x` prefix) of the compressed point.
    pub fn __str__(&self) -> String {
        let compressed = self.to_bytes();
        hex::encode(compressed)
    }

    /// Python `a + b`.
    #[must_use]
    pub fn __add__(&self, rhs: &Self) -> Self {
        self + rhs
    }

    /// Python `a += b`.
    pub fn __iadd__(&mut self, rhs: &Self) {
        *self += rhs;
    }
}
387
#[cfg(feature = "py-bindings")]
mod pybindings {
    use super::*;

    use crate::parse_hex::parse_hex_string;

    use chia_traits::{FromJsonDict, ToJsonDict};

    impl ToJsonDict for PublicKey {
        /// Encodes the key as a `0x`-prefixed hex string of the compressed point.
        fn to_json_dict(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
            let encoded = format!("0x{}", hex::encode(self.to_bytes()));
            Ok(encoded.into_pyobject(py)?.into_any().unbind())
        }
    }

    impl FromJsonDict for PublicKey {
        /// Parses a hex string (optionally `0x`-prefixed, per the tests) and
        /// validates it with [`PublicKey::from_bytes`].
        fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult<Self> {
            // parse_hex_string rejects inputs that are not 48 bytes long, so the
            // array conversion below cannot fail.
            let raw = parse_hex_string(o, 48, "PublicKey")?;
            let bytes: &[u8; 48] = raw.as_slice().try_into().unwrap();
            Ok(Self::from_bytes(bytes)?)
        }
    }
}
417
/// Unit tests for [`PublicKey`] decoding, arithmetic, hashing and derivation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SecretKey;
    use hex::FromHex;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use rstest::rstest;

    #[test]
    fn test_derive_unhardened() {
        let sk_hex = "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb";
        let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
        let pk = sk.public_key();

        // make sure deriving the secret keys produce the same public keys as
        // deriving the public key
        for idx in 0..4_usize {
            let derived_sk = sk.derive_unhardened(idx as u32);
            let derived_pk = pk.derive_unhardened(idx as u32);
            assert_eq!(derived_pk.to_bytes(), derived_sk.public_key().to_bytes());
        }
    }

    #[test]
    fn test_from_bytes() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 48];
        for _i in 0..50 {
            rng.fill(data.as_mut_slice());
            // keep only the compression flag (0x80) in the first byte, clearing
            // the infinity flag and the remaining flag bit
            data[0] = 0x80;
            // just any random bytes are not a valid key and should fail
            match PublicKey::from_bytes(&data) {
                Err(Error::InvalidPublicKey(err)) => {
                    assert!(
                        [
                            BLST_ERROR::BLST_BAD_ENCODING,
                            BLST_ERROR::BLST_POINT_NOT_ON_CURVE
                        ]
                        .contains(&err)
                    );
                }
                Err(e) => {
                    panic!("unexpected error from_bytes(): {e}");
                }
                Ok(v) => {
                    panic!("unexpected value from_bytes(): {v:?}");
                }
            }
        }
    }

    #[rstest]
    #[case(
        "c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001",
        Error::G1NotCanonical
    )]
    #[case(
        "c08000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
        Error::G1NotCanonical
    )]
    #[case(
        "c80000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
        Error::G1NotCanonical
    )]
    #[case(
        "e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
        Error::G1NotCanonical
    )]
    #[case(
        "d00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
        Error::G1NotCanonical
    )]
    #[case(
        "800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
        Error::G1InfinityNotZero
    )]
    #[case(
        "400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
        Error::G1InfinityInvalidBits
    )]
    fn test_from_bytes_failures(#[case] input: &str, #[case] error: Error) {
        let bytes: [u8; 48] = hex::decode(input).unwrap().try_into().unwrap();
        assert_eq!(PublicKey::from_bytes(&bytes).unwrap_err(), error);
    }

    #[test]
    fn test_from_bytes_infinity() {
        let bytes: [u8; 48] = hex::decode("c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap().try_into().unwrap();
        let pk = PublicKey::from_bytes(&bytes).unwrap();
        assert_eq!(pk, PublicKey::default());
    }

    #[test]
    fn test_get_fingerprint() {
        let bytes: [u8; 48] = hex::decode("997cc43ed8788f841fcf3071f6f212b89ba494b6ebaf1bda88c3f9de9d968a61f3b7284a5ee13889399ca71a026549a2")
            .unwrap()
            .as_slice()
            .try_into()
            .unwrap();
        let pk = PublicKey::from_bytes(&bytes).unwrap();
        assert_eq!(pk.get_fingerprint(), 651_010_559);
    }

    #[test]
    fn test_aggregate_pubkey() {
        // from blspy import PrivateKey
        // from blspy import AugSchemeMPL
        // sk = PrivateKey.from_bytes(bytes.fromhex("52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb"))
        // pk = sk.get_g1()
        // pk + pk
        // <G1Element b1b8033286299e7f238aede0d3fea48d133a1e233139085f72c102c2e6cc1f8a4ea64ed2838c10bbd2ef8f78ef271bf3>
        // pk + pk + pk
        // <G1Element a8bc2047d90c04a12e8c38050ec0feb4417b4d5689165cd2cea8a7903aad1778e36548a46d427b5ec571364515e456d6>

        let sk_hex = "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb";
        let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
        let pk = sk.public_key();
        let pk2 = &pk + &pk;
        let pk3 = &pk + &pk + &pk;

        assert_eq!(pk2, PublicKey::from_bytes(&<[u8; 48]>::from_hex("b1b8033286299e7f238aede0d3fea48d133a1e233139085f72c102c2e6cc1f8a4ea64ed2838c10bbd2ef8f78ef271bf3").unwrap()).unwrap());
        assert_eq!(pk3, PublicKey::from_bytes(&<[u8; 48]>::from_hex("a8bc2047d90c04a12e8c38050ec0feb4417b4d5689165cd2cea8a7903aad1778e36548a46d427b5ec571364515e456d6").unwrap()).unwrap());
    }

    #[test]
    fn test_roundtrip() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            rng.fill(data.as_mut_slice());
            let sk = SecretKey::from_seed(&data);
            let pk = sk.public_key();
            let bytes = pk.to_bytes();
            let pk2 = PublicKey::from_bytes(&bytes).unwrap();
            assert_eq!(pk, pk2);
        }
    }

    #[test]
    fn test_default_is_valid() {
        let pk = PublicKey::default();
        assert!(pk.is_valid());
    }

    #[test]
    fn test_infinity_is_valid() {
        let mut data = [0u8; 48];
        data[0] = 0xc0;
        let pk = PublicKey::from_bytes(&data).unwrap();
        assert!(pk.is_valid());
    }

    #[test]
    fn test_is_valid() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            rng.fill(data.as_mut_slice());
            let sk = SecretKey::from_seed(&data);
            let pk = sk.public_key();
            assert!(pk.is_valid());
        }
    }

    #[test]
    fn test_default_is_inf() {
        let pk = PublicKey::default();
        assert!(pk.is_inf());
    }

    #[test]
    fn test_infinity() {
        let mut data = [0u8; 48];
        data[0] = 0xc0;
        let pk = PublicKey::from_bytes(&data).unwrap();
        assert!(pk.is_inf());
    }

    #[test]
    fn test_is_inf() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..500 {
            rng.fill(data.as_mut_slice());
            let sk = SecretKey::from_seed(&data);
            let pk = sk.public_key();
            assert!(!pk.is_inf());
        }
    }

    #[test]
    fn test_hash() {
        fn hash<T: Hash>(v: T) -> u64 {
            use std::collections::hash_map::DefaultHasher;
            let mut h = DefaultHasher::new();
            v.hash(&mut h);
            h.finish()
        }

        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        rng.fill(data.as_mut_slice());

        let sk = SecretKey::from_seed(&data);
        let pk1 = sk.public_key();
        let pk2 = pk1.derive_unhardened(1);
        let pk3 = pk1.derive_unhardened(2);

        assert_ne!(hash(pk2), hash(pk3));
        assert_eq!(hash(pk1.derive_unhardened(42)), hash(pk1.derive_unhardened(42)));
    }

    #[test]
    fn test_debug() {
        let mut data = [0u8; 48];
        data[0] = 0xc0;
        let pk = PublicKey::from_bytes(&data).unwrap();
        assert_eq!(
            format!("{pk:?}"),
            format!("<G1Element {}>", hex::encode(data))
        );
    }

    #[test]
    fn test_generator() {
        assert_eq!(
            hex::encode(PublicKey::generator().to_bytes()),
            "97f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb"
        );
    }

    #[test]
    fn test_from_integer() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            // this integer may not exceed the group order, so leave the top
            // byte as 0
            rng.fill(&mut data[1..]);

            let g1 = PublicKey::from_integer(&data);
            let expected_g1 = SecretKey::from_bytes(&data)
                .expect("invalid secret key")
                .public_key();
            assert_eq!(g1, expected_g1);
        }
    }

    // test cases from zksnark test in chia_rs
    #[rstest]
    #[case(
        "06f6ba2972ab1c83718d747b2d55cca96d08729b1ea5a3ab3479b8efe2d455885abf65f58d1507d7f260cd2a4687db821171c9d8dc5c0f5c3c4fd64b26cf93ff28b2e683c409fb374c4e26cc548c6f7cef891e60b55e6115bb38bbe97822e4d4",
        "a6f6ba2972ab1c83718d747b2d55cca96d08729b1ea5a3ab3479b8efe2d455885abf65f58d1507d7f260cd2a4687db82"
    )]
    #[case(
        "127271e81a1cb5c08a68694fcd5bd52f475d545edd4fbd49b9f6ec402ee1973f9f4102bf3bfccdcbf1b2f862af89a1340d40795c1c09d1e10b1acfa0f3a97a71bf29c11665743fa8d30e57e450b8762959571d6f6d253b236931b93cf634e7cf",
        "b27271e81a1cb5c08a68694fcd5bd52f475d545edd4fbd49b9f6ec402ee1973f9f4102bf3bfccdcbf1b2f862af89a134"
    )]
    #[case(
        "0fe94ac2d68d39d9207ea0cae4bb2177f7352bd754173ed27bd13b4c156f77f8885458886ee9fbd212719f27a96397c110fa7b4f898b1c45c2e82c5d46b52bdad95cae8299d4fd4556ae02baf20a5ec989fc62f28c8b6b3df6dc696f2afb6e20",
        "afe94ac2d68d39d9207ea0cae4bb2177f7352bd754173ed27bd13b4c156f77f8885458886ee9fbd212719f27a96397c1"
    )]
    #[case(
        "13aedc305adfdbc854aa105c41085618484858e6baa276b176fd89415021f7a0c75ff4f9ec39f482f142f1b54c11144815e519df6f71b1db46c83b1d2bdf381fc974059f3ccd87ed5259221dc37c50c3be407b58990d14b6d5bb79dad9ab8c42",
        "b3aedc305adfdbc854aa105c41085618484858e6baa276b176fd89415021f7a0c75ff4f9ec39f482f142f1b54c111448"
    )]
    fn test_from_uncompressed(#[case] input: &str, #[case] expect: &str) {
        let input = hex::decode(input).unwrap();
        let g1 = PublicKey::from_uncompressed(input.as_slice().try_into().unwrap()).unwrap();
        let compressed = g1.to_bytes();
        assert_eq!(hex::encode(compressed), expect);
    }

    #[test]
    fn test_negate_roundtrip() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            // this integer may not exceed the group order, so leave the top
            // byte as 0
            rng.fill(&mut data[1..]);

            let g1 = PublicKey::from_integer(&data);
            let mut g1_neg = g1;
            g1_neg.negate();
            assert_ne!(g1_neg, g1);

            g1_neg.negate();
            assert_eq!(g1_neg, g1);
        }
    }

    #[test]
    fn test_negate_infinity() {
        let g1 = PublicKey::default();
        let mut g1_neg = g1;
        // negate on infinity is a no-op
        g1_neg.negate();
        assert_eq!(g1_neg, g1);
    }

    #[test]
    fn test_negate() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            // this integer may not exceed the group order, so leave the top
            // byte as 0
            rng.fill(&mut data[1..]);

            let g1 = PublicKey::from_integer(&data);
            let mut g1_neg = g1;
            g1_neg.negate();

            let mut g1_double = g1;
            // adding the negative undoes adding the positive
            g1_double += &g1;
            assert_ne!(g1_double, g1);
            g1_double += &g1_neg;
            assert_eq!(g1_double, g1);
        }
    }

    #[test]
    fn test_sub_assign_and_neg() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            // this integer may not exceed the group order, so leave the top
            // byte as 0
            rng.fill(&mut data[1..]);

            let g1 = PublicKey::from_integer(&data);

            // negating by value and by reference agree
            assert_eq!(-g1, -&g1);

            // (g1 + g1 + g1) - g1 == 2 * g1
            let mut acc = g1;
            acc += &g1;
            acc += &g1;
            acc -= &g1;
            let mut g1_double = g1;
            g1_double.scalar_multiply(&[2]);
            assert_eq!(acc, g1_double);

            // subtracting a point from itself yields infinity
            let mut diff = g1;
            diff -= &g1;
            assert!(diff.is_inf());
        }
    }

    #[test]
    fn test_scalar_multiply() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            // this integer may not exceed the group order, so leave the top
            // byte as 0
            rng.fill(&mut data[1..]);

            let mut g1 = PublicKey::from_integer(&data);
            let mut g1_double = g1;
            g1_double += &g1;
            assert_ne!(g1_double, g1);
            // scalar multiply by 2 is the same as adding oneself
            g1.scalar_multiply(&[2]);
            assert_eq!(g1_double, g1);
        }
    }

    #[test]
    fn test_hash_to_g1_different_dst() {
        const DEFAULT_DST: &[u8] = b"BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_AUG_";
        const CUSTOM_DST: &[u8] = b"foobar";

        let mut rng = StdRng::seed_from_u64(1337);
        let mut msg = [0u8; 32];
        for _i in 0..50 {
            rng.fill(&mut msg);
            let default_hash = hash_to_g1(&msg);
            assert_eq!(default_hash, hash_to_g1_with_dst(&msg, DEFAULT_DST));
            assert_ne!(default_hash, hash_to_g1_with_dst(&msg, CUSTOM_DST));
        }
    }

    // test cases from clvm_rs
    #[rstest]
    #[case(
        "abcdef0123456789",
        "88e7302bf1fa8fcdecfb96f6b81475c3564d3bcaf552ccb338b1c48b9ba18ab7195c5067fe94fb216478188c0a3bef4a"
    )]
    fn test_hash_to_g1(#[case] input: &str, #[case] expect: &str) {
        let g1 = hash_to_g1(input.as_bytes());
        assert_eq!(hex::encode(g1.to_bytes()), expect);
    }

    // test cases from clvm_rs
    #[rstest]
    #[case(
        "abcdef0123456789",
        "BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_NUL_",
        "8dd8e3a9197ddefdc25dde980d219004d6aa130d1af9b1808f8b2b004ae94484ac62a08a739ec7843388019a79c437b0"
    )]
    #[case(
        "abcdef0123456789",
        "BLS_SIG_BLS12381G1_XMD:SHA-256_SSWU_RO_AUG_",
        "88e7302bf1fa8fcdecfb96f6b81475c3564d3bcaf552ccb338b1c48b9ba18ab7195c5067fe94fb216478188c0a3bef4a"
    )]
    fn test_hash_to_g1_with_dst(#[case] input: &str, #[case] dst: &str, #[case] expect: &str) {
        let g1 = hash_to_g1_with_dst(input.as_bytes(), dst.as_bytes());
        assert_eq!(hex::encode(g1.to_bytes()), expect);
    }
}
805
/// Tests for the Python JSON-dict conversions of [`PublicKey`].
#[cfg(test)]
#[cfg(feature = "py-bindings")]
mod pytests {
    use super::*;
    use crate::SecretKey;
    use pyo3::Python;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use rstest::rstest;

    /// Round-trips 50 seeded random keys through to/from JSON dict.
    #[test]
    fn test_json_dict_roundtrip() {
        Python::initialize();
        let mut rng = StdRng::seed_from_u64(1337);
        let mut seed = [0u8; 32];
        for _ in 0..50 {
            rng.fill(seed.as_mut_slice());
            let original = SecretKey::from_seed(&seed).public_key();
            Python::attach(|py| {
                let json = original.to_json_dict(py).expect("to_json_dict");
                let cls = py.get_type::<PublicKey>();
                let decoded: PublicKey = PublicKey::from_json_dict(&cls, py, json.bind(py))
                    .unwrap()
                    .extract(py)
                    .unwrap();
                assert_eq!(original, decoded);
            });
        }
    }

    /// Malformed hex strings must be rejected with a descriptive message.
    #[rstest]
    #[case(
        "0x000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e",
        "PublicKey, invalid length 47 expected 48"
    )]
    #[case(
        "0x000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f00",
        "PublicKey, invalid length 49 expected 48"
    )]
    #[case(
        "000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e",
        "PublicKey, invalid length 47 expected 48"
    )]
    #[case(
        "000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f00",
        "PublicKey, invalid length 49 expected 48"
    )]
    #[case(
        "0x00r102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f",
        "invalid hex"
    )]
    fn test_json_dict(#[case] input: &str, #[case] msg: &str) {
        Python::initialize();
        Python::attach(|py| {
            let cls = py.get_type::<PublicKey>();
            let py_input = input.to_string().into_pyobject(py).unwrap().into_any();
            let err = PublicKey::from_json_dict(&cls, py, &py_input).unwrap_err();
            assert_eq!(err.value(py).to_string(), msg);
        });
    }
}