1use crate::{DerivableKey, Error, PublicKey, Result};
2use blst::*;
3use chia_sha2::Sha256;
4use chia_traits::{Streamable, read_bytes};
5use hkdf::HkdfExtract;
6#[cfg(feature = "py-bindings")]
7use pyo3::exceptions::PyNotImplementedError;
8#[cfg(feature = "py-bindings")]
9use pyo3::prelude::*;
10#[cfg(feature = "py-bindings")]
11use pyo3::types::PyType;
12use std::fmt;
13use std::hash::{Hash, Hasher};
14use std::io::Cursor;
15use std::mem::MaybeUninit;
16use std::ops::{Add, AddAssign};
17
/// A BLS secret key, stored as a raw `blst_scalar`.
///
/// With the `py-bindings` feature enabled the type is exposed to Python as
/// the frozen class `PrivateKey`. Equality compares the wrapped scalar.
#[cfg_attr(
    feature = "py-bindings",
    pyo3::pyclass(frozen, name = "PrivateKey"),
    derive(chia_py_streamable_macro::PyStreamable)
)]
#[derive(PartialEq, Eq, Clone)]
pub struct SecretKey(pub(crate) blst_scalar);
25
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for SecretKey {
    /// Builds a key for fuzzing: 32 bytes are pulled from the unstructured
    /// input and used as the seed for [`SecretKey::from_seed`].
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let mut entropy = [0_u8; 32];
        u.fill_buffer(&mut entropy)?;
        Ok(Self::from_seed(&entropy))
    }
}
34
/// Returns a copy of `input` with every bit of every byte inverted.
fn flip_bits(input: [u8; 32]) -> [u8; 32] {
    // For a `u8`, bitwise NOT is the same operation as XOR with 0xff.
    input.map(|byte| !byte)
}
42
43fn ikm_to_lamport_sk(ikm: &[u8; 32], salt: [u8; 4]) -> [u8; 255 * 32] {
44 let mut extracter = HkdfExtract::<sha2::Sha256>::new(Some(&salt));
45 extracter.input_ikm(ikm);
46 let (_, h) = extracter.finalize();
47
48 let mut output = [0_u8; 255 * 32];
49 h.expand(&[], &mut output).unwrap();
50 output
51}
52
53fn to_lamport_pk(ikm: [u8; 32], idx: u32) -> [u8; 32] {
54 let not_ikm = flip_bits(ikm);
55 let salt = idx.to_be_bytes();
56
57 let mut lamport0 = ikm_to_lamport_sk(&ikm, salt);
58 let mut lamport1 = ikm_to_lamport_sk(¬_ikm, salt);
59
60 for i in (0..32 * 255).step_by(32) {
61 let hash = sha256(&lamport0[i..i + 32]);
62 lamport0[i..i + 32].copy_from_slice(&hash);
63 }
64 for i in (0..32 * 255).step_by(32) {
65 let hash = sha256(&lamport1[i..i + 32]);
66 lamport1[i..i + 32].copy_from_slice(&hash);
67 }
68
69 let mut hasher = Sha256::new();
70 hasher.update(lamport0);
71 hasher.update(lamport1);
72 hasher.finalize()
73}
74
75fn sha256(bytes: &[u8]) -> [u8; 32] {
76 let mut hasher = Sha256::new();
77 hasher.update(bytes);
78 hasher.finalize()
79}
80
/// Returns `true` if every byte of `buf` is zero (vacuously true for an
/// empty slice).
///
/// The slice is split into an unaligned head, a `u128`-aligned middle and an
/// unaligned tail so the bulk of the data is compared 16 bytes at a time.
pub fn is_all_zero(buf: &[u8]) -> bool {
    // SAFETY: every bit pattern is a valid `u128`, so reinterpreting the
    // aligned middle of a byte slice as `u128`s is sound.
    let (head, body, tail) = unsafe { buf.align_to::<u128>() };

    let bytes_are_zero = |part: &[u8]| part.iter().all(|&byte| byte == 0);

    bytes_are_zero(head) && bytes_are_zero(tail) && body.iter().all(|&word| word == 0)
}
88
89impl SecretKey {
90 #[must_use]
94 pub fn from_seed(seed: &[u8]) -> Self {
95 assert!(seed.len() >= 32);
98
99 let bytes = unsafe {
100 let mut scalar = MaybeUninit::<blst_scalar>::uninit();
101 blst_keygen_v3(
102 scalar.as_mut_ptr(),
103 seed.as_ptr(),
104 seed.len(),
105 std::ptr::null(),
106 0,
107 );
108 let mut bytes = MaybeUninit::<[u8; 32]>::uninit();
109 blst_bendian_from_scalar(bytes.as_mut_ptr().cast::<u8>(), &scalar.assume_init());
110 bytes.assume_init()
111 };
112 Self::from_bytes(&bytes).expect("from_seed")
113 }
114
115 pub fn from_bytes(bytes: &[u8; 32]) -> Result<Self> {
116 let pk = unsafe {
117 let mut pk = MaybeUninit::<blst_scalar>::uninit();
118 blst_scalar_from_bendian(pk.as_mut_ptr(), bytes.as_ptr());
119 pk.assume_init()
120 };
121
122 if is_all_zero(bytes) {
123 return Ok(Self(pk));
125 }
126
127 if unsafe { !blst_sk_check(&raw const pk) } {
128 return Err(Error::SecretKeyGroupOrder);
129 }
130
131 Ok(Self(pk))
132 }
133
134 pub fn to_bytes(&self) -> [u8; 32] {
135 unsafe {
136 let mut bytes = MaybeUninit::<[u8; 32]>::uninit();
137 blst_bendian_from_scalar(bytes.as_mut_ptr().cast::<u8>(), &raw const self.0);
138 bytes.assume_init()
139 }
140 }
141
142 pub fn public_key(&self) -> PublicKey {
143 let p1 = unsafe {
144 let mut p1 = MaybeUninit::<blst_p1>::uninit();
145 blst_sk_to_pk_in_g1(p1.as_mut_ptr(), &raw const self.0);
146 p1.assume_init()
147 };
148 PublicKey(p1)
149 }
150
151 #[must_use]
152 pub fn derive_hardened(&self, idx: u32) -> SecretKey {
153 SecretKey::from_seed(to_lamport_pk(self.to_bytes(), idx).as_slice())
156 }
157}
158
159impl Streamable for SecretKey {
160 fn update_digest(&self, digest: &mut Sha256) {
161 digest.update(self.to_bytes());
162 }
163
164 fn stream(&self, out: &mut Vec<u8>) -> chia_traits::chia_error::Result<()> {
165 out.extend_from_slice(&self.to_bytes());
166 Ok(())
167 }
168
169 fn parse<const TRUSTED: bool>(
170 input: &mut Cursor<&[u8]>,
171 ) -> chia_traits::chia_error::Result<Self> {
172 Ok(Self::from_bytes(
173 read_bytes(input, 32)?.try_into().unwrap(),
174 )?)
175 }
176}
177
178impl Hash for SecretKey {
179 fn hash<H: Hasher>(&self, state: &mut H) {
180 state.write(&self.to_bytes());
181 }
182}
183
184impl Add<&SecretKey> for &SecretKey {
185 type Output = SecretKey;
186 fn add(self, rhs: &SecretKey) -> SecretKey {
187 let scalar = unsafe {
188 let mut ret = MaybeUninit::<blst_scalar>::uninit();
189 blst_sk_add_n_check(ret.as_mut_ptr(), &raw const self.0, &raw const rhs.0);
190 ret.assume_init()
191 };
192 SecretKey(scalar)
193 }
194}
195
196impl Add<&SecretKey> for SecretKey {
197 type Output = SecretKey;
198 fn add(mut self, rhs: &SecretKey) -> SecretKey {
199 unsafe {
200 blst_sk_add_n_check(&raw mut self.0, &raw const self.0, &raw const rhs.0);
201 self
202 }
203 }
204}
205
206impl AddAssign<&SecretKey> for SecretKey {
207 fn add_assign(&mut self, rhs: &SecretKey) {
208 unsafe {
209 blst_sk_add_n_check(&raw mut self.0, &raw const self.0, &raw const rhs.0);
210 }
211 }
212}
213
214impl fmt::Debug for SecretKey {
215 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
216 formatter.write_fmt(format_args!(
217 "<PrivateKey {}>",
218 &hex::encode(self.to_bytes())
219 ))
220 }
221}
222
223impl DerivableKey for SecretKey {
224 fn derive_unhardened(&self, idx: u32) -> Self {
225 let pk = self.public_key();
226
227 let mut hasher = Sha256::new();
228 hasher.update(pk.to_bytes());
229 hasher.update(idx.to_be_bytes());
230 let digest = hasher.finalize();
231
232 let scalar = unsafe {
233 let mut scalar = MaybeUninit::<blst_scalar>::uninit();
234 let success =
235 blst_scalar_from_be_bytes(scalar.as_mut_ptr(), digest.as_ptr(), digest.len());
236 assert!(success);
237 let success =
238 blst_sk_add_n_check(scalar.as_mut_ptr(), scalar.as_ptr(), &raw const self.0);
239 assert!(success);
240 scalar.assume_init()
241 };
242 Self(scalar)
243 }
244}
245
#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl SecretKey {
    /// Size in bytes of a serialized private key.
    #[classattr]
    pub const PRIVATE_KEY_SIZE: usize = 32;

    /// Signs `msg` with this key.
    ///
    /// When `final_pk` is given, its serialized bytes are prepended to `msg`
    /// and the result is passed to `sign_raw`; otherwise `sign` is used.
    #[pyo3(signature = (msg, final_pk=None))]
    pub fn sign(&self, msg: &[u8], final_pk: Option<PublicKey>) -> crate::Signature {
        let Some(prefix) = final_pk else {
            return crate::sign(self, msg);
        };

        let mut aug_msg = prefix.to_bytes().to_vec();
        aug_msg.extend_from_slice(msg);
        crate::sign_raw(self, aug_msg)
    }

    /// Returns the public key; same result as `public_key`.
    pub fn get_g1(&self) -> PublicKey {
        self.public_key()
    }

    /// Python-facing `public_key`.
    #[pyo3(name = "public_key")]
    pub fn py_public_key(&self) -> PublicKey {
        self.public_key()
    }

    /// Returns the key as a hex string without a prefix.
    pub fn __str__(&self) -> String {
        let encoded = self.to_bytes();
        hex::encode(encoded)
    }

    /// Always raises `NotImplementedError`.
    #[classmethod]
    #[pyo3(name = "from_parent")]
    pub fn from_parent(_cls: &Bound<'_, PyType>, _instance: &Self) -> PyResult<Py<PyAny>> {
        let unsupported =
            PyNotImplementedError::new_err("SecretKey does not support from_parent().");
        Err(unsupported)
    }

    /// Python-facing `derive_hardened`.
    #[pyo3(name = "derive_hardened")]
    #[must_use]
    pub fn py_derive_hardened(&self, idx: u32) -> Self {
        self.derive_hardened(idx)
    }

    /// Python-facing `derive_unhardened`.
    #[pyo3(name = "derive_unhardened")]
    #[must_use]
    pub fn py_derive_unhardened(&self, idx: u32) -> Self {
        self.derive_unhardened(idx)
    }

    /// Python-facing `from_seed`; the seed must be at least 32 bytes long.
    #[pyo3(name = "from_seed")]
    #[staticmethod]
    pub fn py_from_seed(seed: &[u8]) -> Self {
        Self::from_seed(seed)
    }
}
303
#[cfg(feature = "py-bindings")]
mod pybindings {
    use super::*;

    use crate::parse_hex::parse_hex_string;

    use chia_traits::{FromJsonDict, ToJsonDict};

    impl ToJsonDict for SecretKey {
        /// Serializes the key as a Python string: `0x` followed by the hex
        /// encoding of the 32 key bytes.
        fn to_json_dict(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
            let text = format!("0x{}", hex::encode(self.to_bytes()));
            let object = text.into_pyobject(py)?;
            Ok(object.into_any().unbind())
        }
    }

    impl FromJsonDict for SecretKey {
        /// Parses a key from a hex string holding exactly 32 bytes.
        fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult<Self> {
            let decoded = parse_hex_string(o, 32, "PrivateKey")?;
            // `parse_hex_string` was asked for 32 bytes, so this conversion
            // is expected to succeed.
            let raw: &[u8; 32] = decoded.as_slice().try_into().unwrap();
            Ok(Self::from_bytes(raw)?)
        }
    }
}
333
334#[cfg(test)]
335mod tests {
336 use super::*;
337 use hex::FromHex;
338 use rand::rngs::StdRng;
339 use rand::{Rng, SeedableRng};
340
341 #[test]
342 fn test_make_key() {
343 let test_cases = &[
349 (
350 "fc795be0c3f18c50dddb34e72179dc597d64055497ecc1e69e2e56a5409651bc139aae8070d4df0ea14d8d2a518a9a00bb1cc6e92e053fe34051f6821df9164c",
351 "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb",
352 ),
353 (
354 "b873212f885ccffbf4692afcb84bc2e55886de2dfa07d90f5c3c239abc31c0a6ce047e30fd8bf6a281e71389aa82d73df74c7bbfb3b06b4639a5cee775cccd3c",
355 "35d65c35d926f62ba2dd128754ddb556edb4e2c926237ab9e02a23e7b3533613",
356 ),
357 (
358 "3e066d7dee2dbf8fcd3fe240a3975658ca118a8f6f4ca81cf99104944604b05a5090a79d99e545704b914ca0397fedb82fd00fd6a72098703709c891a065ee49",
359 "59095c391107936599b7ee6f09067979b321932bd62e23c7f53ed5fb19f851f6",
360 ),
361 ];
362
363 for (seed, sk) in test_cases {
364 assert_eq!(
365 SecretKey::from_seed(&<[u8; 64]>::from_hex(seed).unwrap())
366 .to_bytes()
367 .to_vec(),
368 Vec::<u8>::from_hex(sk).unwrap()
369 );
370 }
371 }
372
373 #[test]
374 fn test_derive_unhardened() {
375 let sk_hex = "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb";
389 let derived_hex = [
390 "399638f99d446500f3c3a363f24c2b0634ad7caf646f503455093f35f29290bd",
391 "3dcb4098ad925d8940e2f516d2d5a4dbab393db928a8c6cb06b93066a09a843a",
392 "13115c8fb68a3d667938dac2ffc6b867a4a0f216bbb228aa43d6bdde14245575",
393 "52e7e9f2fb51f2c5705aea8e11ac82737b95e664ae578f015af22031d956f92b",
394 ];
395 let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
396
397 for (i, hex) in derived_hex.iter().enumerate() {
398 let derived = sk.derive_unhardened(i as u32);
399 assert_eq!(derived.to_bytes(), <[u8; 32]>::from_hex(hex).unwrap());
400 }
401 }
402
403 #[test]
404 fn test_public_key() {
405 let test_cases = [
416 (
417 "5aac8405befe4cb3748a67177c56df26355f1f98d979afdb0b2f97858d2f71c3",
418 "b9de000821a610ef644d160c810e35113742ff498002c2deccd8f1a349e423047e9b3fc17ebfc733dbee8fd902ba2961",
419 ),
420 (
421 "23f1fb291d3bd7434282578b842d5ea4785994bb89bd2c94896d1b4be6c70ba2",
422 "96f304a5885e67abdeab5e1ed0576780a1368777ea7760124834529e8694a1837a20ffea107b9769c4f92a1f6c167e69",
423 ),
424 (
425 "2bc1d6d6efe58d365c29ccb7ad12c8457c0eec70a29003073692ac4cb1cd7ba2",
426 "b10568446def64b17fc9b6d614ae036deaac3f2d654e12e45ea04b19208246e0d760e8826426e97f9f0666b7ce340d75",
427 ),
428 (
429 "2bfc8672d859700e30aa6c8edc24a8ce9e6dc53bb1ef936f82de722847d05b9e",
430 "9641472acbd6af7e5313d2500791b87117612af43eef929cf7975aaaa5a203a32698a8ef53763a84d90ad3f00b86ad66",
431 ),
432 (
433 "3311f883dad1e39c52bf82d5870d05371c0b1200576287b5160808f55568151b",
434 "928ea102b5a3e3efe4f4c240d3458a568dfeb505e02901a85ed70a384944b0c08c703a35245322709921b8f2b7f5e54a",
435 ),
436 ];
437
438 for (sk_hex, pk_hex) in test_cases {
439 let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
440 let pk = sk.public_key();
441 assert_eq!(
442 pk,
443 PublicKey::from_bytes(&<[u8; 48]>::from_hex(pk_hex).unwrap()).unwrap()
444 );
445 }
446 }
447
448 #[test]
449 fn test_derive_hardened() {
450 let sk_hex = "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb";
464 let derived_hex = [
465 "05eccb2d70e814f51a30d8b9965505605c677afa97228fa2419db583a8121db9",
466 "612ae96bdce2e9bc01693ac579918fbb559e04ec365cce9b66bb80e328f62c46",
467 "5df14a0a34fd6c30a80136d4103f0a93422ce82d5c537bebbecbc56e19fee5b9",
468 "3ea55db88d9a6bf5f1d9c9de072e3c9a56b13f4156d72fca7880cd39b4bd4fdc",
469 ];
470 let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
471
472 for (i, hex) in derived_hex.iter().enumerate() {
473 let derived = sk.derive_hardened(i as u32);
474 assert_eq!(derived.to_bytes(), <[u8; 32]>::from_hex(hex).unwrap());
475 }
476 }
477
478 #[test]
479 fn test_debug() {
480 let sk_hex = "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb";
481 let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
482 assert_eq!(format!("{sk:?}"), format!("<PrivateKey {sk_hex}>"));
483 }
484
485 #[test]
486 fn test_hash() {
487 fn hash<T: Hash>(v: &T) -> u64 {
488 use std::collections::hash_map::DefaultHasher;
489 let mut h = DefaultHasher::new();
490 v.hash(&mut h);
491 h.finish()
492 }
493
494 let mut rng = StdRng::seed_from_u64(1337);
495 let mut data = [0u8; 32];
496 rng.fill(data.as_mut_slice());
497
498 let sk1 = SecretKey::from_seed(&data);
499 let sk2 = SecretKey::from_seed(&data);
500
501 rng.fill(data.as_mut_slice());
502 let sk3 = SecretKey::from_seed(&data);
503
504 assert!(hash(&sk1) == hash(&sk2));
505 assert!(hash(&sk1) != hash(&sk3));
506 }
507
508 #[test]
509 fn test_from_bytes() {
510 let mut rng = StdRng::seed_from_u64(1337);
511 let mut data = [0u8; 32];
512 for _i in 0..50 {
513 rng.fill(data.as_mut_slice());
514 data[0] |= 0x80;
516 assert_eq!(
518 SecretKey::from_bytes(&data).unwrap_err(),
519 Error::SecretKeyGroupOrder
520 );
521 }
522 }
523
524 #[test]
525 fn test_from_bytes_zero() {
526 let data = [0u8; 32];
527 let _sk = SecretKey::from_bytes(&data).unwrap();
528 }
529
530 #[test]
531 fn test_aggregate_secret_key() {
532 let sk_hex = "5aac8405befe4cb3748a67177c56df26355f1f98d979afdb0b2f97858d2f71c3";
533 let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
534 let sk2 = &sk + &sk;
535 let sk3 = &sk + &sk + &sk;
536
537 assert_eq!(
538 sk2,
539 SecretKey::from_bytes(
540 &<[u8; 32]>::from_hex(
541 "416b60b8545f1c1eb5daf626ef0be64717009b2eb2f503b7165f2f0c1a5ee385"
542 )
543 .unwrap()
544 )
545 .unwrap()
546 );
547
548 assert_eq!(
549 sk3,
550 SecretKey::from_bytes(
551 &<[u8; 32]>::from_hex(
552 "282a3d6ae9bfeb89f72b853661c0ed67f8a216c48c705793218ec692a78e5547"
553 )
554 .unwrap()
555 )
556 .unwrap()
557 );
558 }
559
560 #[test]
561 fn test_roundtrip() {
562 let mut rng = StdRng::seed_from_u64(1337);
563 let mut data = [0u8; 32];
564 for _i in 0..50 {
565 rng.fill(data.as_mut_slice());
566 let sk = SecretKey::from_seed(&data);
567 let bytes = sk.to_bytes();
568 let sk2 = SecretKey::from_bytes(&bytes).unwrap();
569 assert_eq!(sk, sk2);
570 assert_eq!(sk.public_key(), sk2.public_key());
571 }
572 }
573}
574
575#[cfg(test)]
576#[cfg(feature = "py-bindings")]
577mod pytests {
578 use super::*;
579 use pyo3::Python;
580 use rand::rngs::StdRng;
581 use rand::{Rng, SeedableRng};
582 use rstest::rstest;
583
584 #[test]
585 fn test_json_dict_roundtrip() {
586 Python::initialize();
587 let mut rng = StdRng::seed_from_u64(1337);
588 let mut data = [0u8; 32];
589 for _i in 0..50 {
590 rng.fill(data.as_mut_slice());
591 let sk = SecretKey::from_seed(&data);
592 Python::attach(|py| {
593 let string = sk.to_json_dict(py).expect("to_json_dict");
594 let py_class = py.get_type::<SecretKey>();
595 let sk2 = SecretKey::from_json_dict(&py_class, py, string.bind(py))
596 .unwrap()
597 .extract(py)
598 .unwrap();
599 assert_eq!(sk, sk2);
600 assert_eq!(sk.public_key(), sk2.public_key());
601 });
602 }
603 }
604
605 #[rstest]
606 #[case(
607 "0x000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e",
608 "PrivateKey, invalid length 31 expected 32"
609 )]
610 #[case(
611 "0x000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f00",
612 "PrivateKey, invalid length 33 expected 32"
613 )]
614 #[case(
615 "000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f00",
616 "PrivateKey, invalid length 33 expected 32"
617 )]
618 #[case(
619 "000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e",
620 "PrivateKey, invalid length 31 expected 32"
621 )]
622 #[case(
623 "0r0102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f",
624 "invalid hex"
625 )]
626 fn test_json_dict(#[case] input: &str, #[case] msg: &str) {
627 Python::initialize();
628 Python::attach(|py| {
629 let py_class = py.get_type::<SecretKey>();
630 let err = SecretKey::from_json_dict(
631 &py_class,
632 py,
633 &input.to_string().into_pyobject(py).unwrap().into_any(),
634 )
635 .unwrap_err();
636 assert_eq!(err.value(py).to_string(), msg.to_string());
637 });
638 }
639}