1use crate::{DerivableKey, Error, PublicKey, Result};
2use blst::*;
3use chik_sha2::Sha256;
4use chik_traits::{read_bytes, Streamable};
5use hkdf::HkdfExtract;
6#[cfg(feature = "py-bindings")]
7use pyo3::exceptions::PyNotImplementedError;
8#[cfg(feature = "py-bindings")]
9use pyo3::prelude::*;
10#[cfg(feature = "py-bindings")]
11use pyo3::types::PyType;
12use std::fmt;
13use std::hash::{Hash, Hasher};
14use std::io::Cursor;
15use std::mem::MaybeUninit;
16use std::ops::{Add, AddAssign};
17
// Secret key for BLS signing: a thin wrapper around blst's `blst_scalar`.
//
// With the `py-bindings` feature the type is exported to Python as the frozen
// class `PrivateKey` and gains the streamable helpers from `PyStreamable`.
// Equality compares the wrapped scalar directly (derived `PartialEq`/`Eq`).
//
// NOTE: plain `//` comments on purpose — `///` doc comments on a pyo3 pyclass
// become the Python class `__doc__`.
#[cfg_attr(
    feature = "py-bindings",
    pyo3::pyclass(frozen, name = "PrivateKey"),
    derive(chik_py_streamable_macro::PyStreamable)
)]
#[derive(PartialEq, Eq, Clone)]
pub struct SecretKey(pub(crate) blst_scalar);
25
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for SecretKey {
    /// Builds a key by running [`SecretKey::from_seed`] over 32 bytes drawn
    /// from the fuzzer input.
    ///
    /// The result of `fill_buffer` is deliberately ignored, so a key is always
    /// produced from whatever the 32-byte seed buffer ends up holding.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let mut seed_bytes = [0_u8; 32];
        let _ = u.fill_buffer(&mut seed_bytes[..]);
        let key = SecretKey::from_seed(&seed_bytes);
        Ok(key)
    }
}
34
/// Returns the bitwise complement of every byte of `input`.
///
/// `to_lamport_pk` uses this to derive the input keying material for the
/// second Lamport half from the complemented parent key bytes.
fn flip_bits(input: [u8; 32]) -> [u8; 32] {
    // `!b` on a `u8` is identical to `b ^ 0xff`; `array::map` avoids the
    // manual index loop and its bounds checks.
    input.map(|byte| !byte)
}
42
43fn ikm_to_lamport_sk(ikm: &[u8; 32], salt: [u8; 4]) -> [u8; 255 * 32] {
44 let mut extracter = HkdfExtract::<sha2::Sha256>::new(Some(&salt));
45 extracter.input_ikm(ikm);
46 let (_, h) = extracter.finalize();
47
48 let mut output = [0_u8; 255 * 32];
49 h.expand(&[], &mut output).unwrap();
50 output
51}
52
53fn to_lamport_pk(ikm: [u8; 32], idx: u32) -> [u8; 32] {
54 let not_ikm = flip_bits(ikm);
55 let salt = idx.to_be_bytes();
56
57 let mut lamport0 = ikm_to_lamport_sk(&ikm, salt);
58 let mut lamport1 = ikm_to_lamport_sk(¬_ikm, salt);
59
60 for i in (0..32 * 255).step_by(32) {
61 let hash = sha256(&lamport0[i..i + 32]);
62 lamport0[i..i + 32].copy_from_slice(&hash);
63 }
64 for i in (0..32 * 255).step_by(32) {
65 let hash = sha256(&lamport1[i..i + 32]);
66 lamport1[i..i + 32].copy_from_slice(&hash);
67 }
68
69 let mut hasher = Sha256::new();
70 hasher.update(lamport0);
71 hasher.update(lamport1);
72 hasher.finalize()
73}
74
75fn sha256(bytes: &[u8]) -> [u8; 32] {
76 let mut hasher = Sha256::new();
77 hasher.update(bytes);
78 hasher.finalize()
79}
80
/// Returns `true` if every byte of `buf` is zero; an empty slice counts as
/// all-zero.
///
/// The aligned middle of the slice is compared 16 bytes at a time as `u128`
/// words. The unaligned head and tail are compared byte by byte.
pub fn is_all_zero(buf: &[u8]) -> bool {
    // SAFETY: every bit pattern is a valid `u128`, and `align_to` only places
    // bytes in the middle slice when they are correctly aligned for `u128`.
    let (head, words, tail) = unsafe { buf.align_to::<u128>() };

    let bytes_are_zero = |bytes: &[u8]| bytes.iter().all(|&b| b == 0);

    bytes_are_zero(head) && bytes_are_zero(tail) && words.iter().all(|&w| w == 0)
}
88
impl SecretKey {
    /// Deterministically derives a secret key from `seed` using blst's
    /// `blst_keygen_v3`, with no `info` bytes (null pointer, length 0).
    ///
    /// # Panics
    ///
    /// Panics if `seed` is shorter than 32 bytes. Also panics (message
    /// `"from_seed"`) if the generated key is rejected by
    /// [`SecretKey::from_bytes`].
    #[must_use]
    pub fn from_seed(seed: &[u8]) -> Self {
        assert!(seed.len() >= 32);

        // SAFETY: `seed` points to `seed.len()` readable bytes, and the null
        // `info` pointer is paired with length 0. `scalar` and `bytes` are
        // relied upon to be fully written by blst before `assume_init`.
        let bytes = unsafe {
            let mut scalar = MaybeUninit::<blst_scalar>::uninit();
            blst_keygen_v3(
                scalar.as_mut_ptr(),
                seed.as_ptr(),
                seed.len(),
                std::ptr::null(),
                0,
            );
            let mut bytes = MaybeUninit::<[u8; 32]>::uninit();
            blst_bendian_from_scalar(bytes.as_mut_ptr().cast::<u8>(), &scalar.assume_init());
            bytes.assume_init()
        };
        // Re-parse the big-endian bytes so the result passes through the same
        // checks as `from_bytes`.
        Self::from_bytes(&bytes).expect("from_seed")
    }

    /// Parses a secret key from its 32-byte big-endian encoding.
    ///
    /// All-zero input is accepted without further checks. Any other value
    /// must pass `blst_sk_check`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::SecretKeyGroupOrder`] when a non-zero `bytes` fails
    /// `blst_sk_check`. Judging by the variant name this means the value is not
    /// below the scalar group order — NOTE(review): confirm against blst docs.
    pub fn from_bytes(bytes: &[u8; 32]) -> Result<Self> {
        // SAFETY: `bytes` provides exactly 32 readable bytes; `pk` is relied
        // upon to be fully written by `blst_scalar_from_bendian` before
        // `assume_init`.
        let pk = unsafe {
            let mut pk = MaybeUninit::<blst_scalar>::uninit();
            blst_scalar_from_bendian(pk.as_mut_ptr(), bytes.as_ptr());
            pk.assume_init()
        };

        // The zero key is allowed and skips `blst_sk_check` entirely
        // (exercised by `test_from_bytes_zero`).
        if is_all_zero(bytes) {
            return Ok(Self(pk));
        }

        // SAFETY: `pk` is an initialized scalar owned by this frame.
        if unsafe { !blst_sk_check(&pk) } {
            return Err(Error::SecretKeyGroupOrder);
        }

        Ok(Self(pk))
    }

    /// Serializes the key to 32 big-endian bytes, the format `from_bytes` reads.
    pub fn to_bytes(&self) -> [u8; 32] {
        // SAFETY: `bytes` is a 32-byte buffer that `blst_bendian_from_scalar`
        // is relied upon to fill completely before `assume_init`.
        unsafe {
            let mut bytes = MaybeUninit::<[u8; 32]>::uninit();
            blst_bendian_from_scalar(bytes.as_mut_ptr().cast::<u8>(), &self.0);
            bytes.assume_init()
        }
    }

    /// Computes the matching public key (a G1 point) with `blst_sk_to_pk_in_g1`.
    pub fn public_key(&self) -> PublicKey {
        // SAFETY: `p1` is relied upon to be fully written by
        // `blst_sk_to_pk_in_g1` before `assume_init`.
        let p1 = unsafe {
            let mut p1 = MaybeUninit::<blst_p1>::uninit();
            blst_sk_to_pk_in_g1(p1.as_mut_ptr(), &self.0);
            p1.assume_init()
        };
        PublicKey(p1)
    }

    /// Hardened child derivation for index `idx`.
    ///
    /// The child is `from_seed` applied to the 32-byte Lamport public key that
    /// `to_lamport_pk` computes from this key's bytes and `idx`.
    #[must_use]
    pub fn derive_hardened(&self, idx: u32) -> SecretKey {
        SecretKey::from_seed(to_lamport_pk(self.to_bytes(), idx).as_slice())
    }
}
158
159impl Streamable for SecretKey {
160 fn update_digest(&self, digest: &mut Sha256) {
161 digest.update(self.to_bytes());
162 }
163
164 fn stream(&self, out: &mut Vec<u8>) -> chik_traits::chik_error::Result<()> {
165 out.extend_from_slice(&self.to_bytes());
166 Ok(())
167 }
168
169 fn parse<const TRUSTED: bool>(
170 input: &mut Cursor<&[u8]>,
171 ) -> chik_traits::chik_error::Result<Self> {
172 Ok(Self::from_bytes(
173 read_bytes(input, 32)?.try_into().unwrap(),
174 )?)
175 }
176}
177
178impl Hash for SecretKey {
179 fn hash<H: Hasher>(&self, state: &mut H) {
180 state.write(&self.to_bytes());
181 }
182}
183
184impl Add<&SecretKey> for &SecretKey {
185 type Output = SecretKey;
186 fn add(self, rhs: &SecretKey) -> SecretKey {
187 let scalar = unsafe {
188 let mut ret = MaybeUninit::<blst_scalar>::uninit();
189 blst_sk_add_n_check(ret.as_mut_ptr(), &self.0, &rhs.0);
190 ret.assume_init()
191 };
192 SecretKey(scalar)
193 }
194}
195
196impl Add<&SecretKey> for SecretKey {
197 type Output = SecretKey;
198 fn add(mut self, rhs: &SecretKey) -> SecretKey {
199 unsafe {
200 blst_sk_add_n_check(&mut self.0, &self.0, &rhs.0);
201 self
202 }
203 }
204}
205
206impl AddAssign<&SecretKey> for SecretKey {
207 fn add_assign(&mut self, rhs: &SecretKey) {
208 unsafe {
209 blst_sk_add_n_check(&mut self.0, &self.0, &rhs.0);
210 }
211 }
212}
213
214impl fmt::Debug for SecretKey {
215 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
216 formatter.write_fmt(format_args!(
217 "<PrivateKey {}>",
218 &hex::encode(self.to_bytes())
219 ))
220 }
221}
222
impl DerivableKey for SecretKey {
    /// Unhardened child derivation for index `idx`.
    ///
    /// Hashes this key's public key bytes followed by `idx` as 4 big-endian
    /// bytes with SHA-256. The digest is converted to a scalar with
    /// `blst_scalar_from_be_bytes`, and this key's scalar is added to it with
    /// `blst_sk_add_n_check`.
    ///
    /// # Panics
    ///
    /// Panics if either blst call returns `false`.
    fn derive_unhardened(&self, idx: u32) -> Self {
        let pk = self.public_key();

        let mut hasher = Sha256::new();
        hasher.update(pk.to_bytes());
        hasher.update(idx.to_be_bytes());
        let digest = hasher.finalize();

        // SAFETY: `digest` provides `digest.len()` readable bytes. The
        // addition reads and writes `scalar` through raw pointers only (in/out
        // aliasing is left to blst). `scalar` is assumed initialized only
        // after both calls have written it and reported success.
        let scalar = unsafe {
            let mut scalar = MaybeUninit::<blst_scalar>::uninit();
            let success =
                blst_scalar_from_be_bytes(scalar.as_mut_ptr(), digest.as_ptr(), digest.len());
            assert!(success);
            let success = blst_sk_add_n_check(scalar.as_mut_ptr(), scalar.as_ptr(), &self.0);
            assert!(success);
            scalar.assume_init()
        };
        Self(scalar)
    }
}
244
#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl SecretKey {
    // NOTE: plain `//` comments throughout — `///` inside `#[pymethods]`
    // would become the Python `__doc__` of the exported members.

    // Size in bytes of a serialized private key, exposed as a class attribute.
    #[classattr]
    pub const PRIVATE_KEY_SIZE: usize = 32;

    // Signs `msg`. Without `final_pk` this is `crate::sign`. With it, the
    // public key's bytes are prepended to `msg` and the concatenation is
    // signed with `crate::sign_raw`.
    #[pyo3(signature = (msg, final_pk=None))]
    pub fn sign(&self, msg: &[u8], final_pk: Option<PublicKey>) -> crate::Signature {
        if let Some(pk) = final_pk {
            let pk_bytes = pk.to_bytes();
            let mut augmented = Vec::with_capacity(pk_bytes.len() + msg.len());
            augmented.extend_from_slice(&pk_bytes);
            augmented.extend_from_slice(msg);
            crate::sign_raw(self, augmented)
        } else {
            crate::sign(self, msg)
        }
    }

    // Alias for `public_key()`.
    pub fn get_g1(&self) -> PublicKey {
        SecretKey::public_key(self)
    }

    // Python name `public_key`; forwards to the Rust `public_key`.
    #[pyo3(name = "public_key")]
    pub fn py_public_key(&self) -> PublicKey {
        SecretKey::public_key(self)
    }

    // `str(key)`: bare hex of the 32 key bytes (no "0x" prefix).
    pub fn __str__(&self) -> String {
        let encoded = self.to_bytes();
        hex::encode(encoded)
    }

    // Always raises `NotImplementedError`.
    #[classmethod]
    #[pyo3(name = "from_parent")]
    pub fn from_parent(_cls: &Bound<'_, PyType>, _instance: &Self) -> PyResult<PyObject> {
        let err = PyNotImplementedError::new_err("SecretKey does not support from_parent().");
        Err(err)
    }

    // Python name `derive_hardened`; forwards to the Rust method.
    #[pyo3(name = "derive_hardened")]
    #[must_use]
    pub fn py_derive_hardened(&self, idx: u32) -> Self {
        SecretKey::derive_hardened(self, idx)
    }

    // Python name `derive_unhardened`; forwards to `DerivableKey`.
    #[pyo3(name = "derive_unhardened")]
    #[must_use]
    pub fn py_derive_unhardened(&self, idx: u32) -> Self {
        DerivableKey::derive_unhardened(self, idx)
    }

    // Python name `from_seed` (static); forwards to the Rust constructor.
    #[pyo3(name = "from_seed")]
    #[staticmethod]
    pub fn py_from_seed(seed: &[u8]) -> Self {
        SecretKey::from_seed(seed)
    }
}
302
#[cfg(feature = "py-bindings")]
mod pybindings {
    use super::*;

    use crate::parse_hex::parse_hex_string;

    use chik_traits::{FromJsonDict, ToJsonDict};

    impl ToJsonDict for SecretKey {
        /// Emits the key as a `"0x"`-prefixed hex string of its 32 bytes.
        fn to_json_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
            let encoded = format!("0x{}", hex::encode(self.to_bytes()));
            let obj = encoded.into_pyobject(py)?;
            Ok(obj.into_any().unbind())
        }
    }

    impl FromJsonDict for SecretKey {
        /// Parses a hex string with `parse_hex_string` (which enforces the
        /// 32-byte length) and validates it with [`SecretKey::from_bytes`].
        fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult<Self> {
            let raw = parse_hex_string(o, 32, "PrivateKey")?;
            let encoded: &[u8; 32] = raw.as_slice().try_into().unwrap();
            let key = Self::from_bytes(encoded)?;
            Ok(key)
        }
    }
}
332
#[cfg(test)]
mod tests {
    use super::*;
    use hex::FromHex;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    // Known-answer vectors: 64-byte seed (hex) -> expected secret key bytes
    // (hex) produced by `SecretKey::from_seed`.
    #[test]
    fn test_make_key() {
        let test_cases = &[
            ("fc795be0c3f18c50dddb34e72179dc597d64055497ecc1e69e2e56a5409651bc139aae8070d4df0ea14d8d2a518a9a00bb1cc6e92e053fe34051f6821df9164c",
            "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb"),
            ("b873212f885ccffbf4692afcb84bc2e55886de2dfa07d90f5c3c239abc31c0a6ce047e30fd8bf6a281e71389aa82d73df74c7bbfb3b06b4639a5cee775cccd3c",
            "35d65c35d926f62ba2dd128754ddb556edb4e2c926237ab9e02a23e7b3533613"),
            ("3e066d7dee2dbf8fcd3fe240a3975658ca118a8f6f4ca81cf99104944604b05a5090a79d99e545704b914ca0397fedb82fd00fd6a72098703709c891a065ee49",
            "59095c391107936599b7ee6f09067979b321932bd62e23c7f53ed5fb19f851f6")
        ];

        for (seed, sk) in test_cases {
            assert_eq!(
                SecretKey::from_seed(&<[u8; 64]>::from_hex(seed).unwrap())
                    .to_bytes()
                    .to_vec(),
                Vec::<u8>::from_hex(sk).unwrap()
            );
        }
    }

    // Known-answer vectors for `derive_unhardened` at indices 0..=3.
    #[test]
    fn test_derive_unhardened() {
        let sk_hex = "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb";
        let derived_hex = [
            "399638f99d446500f3c3a363f24c2b0634ad7caf646f503455093f35f29290bd",
            "3dcb4098ad925d8940e2f516d2d5a4dbab393db928a8c6cb06b93066a09a843a",
            "13115c8fb68a3d667938dac2ffc6b867a4a0f216bbb228aa43d6bdde14245575",
            "52e7e9f2fb51f2c5705aea8e11ac82737b95e664ae578f015af22031d956f92b",
        ];
        let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();

        // The position in `derived_hex` is the derivation index.
        for (i, hex) in derived_hex.iter().enumerate() {
            let derived = sk.derive_unhardened(i as u32);
            assert_eq!(derived.to_bytes(), <[u8; 32]>::from_hex(hex).unwrap());
        }
    }

    // Known-answer vectors: secret key bytes -> expected 48-byte public key.
    #[test]
    fn test_public_key() {
        let test_cases = [
            ("5aac8405befe4cb3748a67177c56df26355f1f98d979afdb0b2f97858d2f71c3",
            "b9de000821a610ef644d160c810e35113742ff498002c2deccd8f1a349e423047e9b3fc17ebfc733dbee8fd902ba2961"),
            ("23f1fb291d3bd7434282578b842d5ea4785994bb89bd2c94896d1b4be6c70ba2",
            "96f304a5885e67abdeab5e1ed0576780a1368777ea7760124834529e8694a1837a20ffea107b9769c4f92a1f6c167e69"),
            ("2bc1d6d6efe58d365c29ccb7ad12c8457c0eec70a29003073692ac4cb1cd7ba2",
            "b10568446def64b17fc9b6d614ae036deaac3f2d654e12e45ea04b19208246e0d760e8826426e97f9f0666b7ce340d75"),
            ("2bfc8672d859700e30aa6c8edc24a8ce9e6dc53bb1ef936f82de722847d05b9e",
            "9641472acbd6af7e5313d2500791b87117612af43eef929cf7975aaaa5a203a32698a8ef53763a84d90ad3f00b86ad66"),
            ("3311f883dad1e39c52bf82d5870d05371c0b1200576287b5160808f55568151b",
            "928ea102b5a3e3efe4f4c240d3458a568dfeb505e02901a85ed70a384944b0c08c703a35245322709921b8f2b7f5e54a"),
        ];

        for (sk_hex, pk_hex) in test_cases {
            let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
            let pk = sk.public_key();
            assert_eq!(
                pk,
                PublicKey::from_bytes(&<[u8; 48]>::from_hex(pk_hex).unwrap()).unwrap()
            );
        }
    }

    // Known-answer vectors for `derive_hardened` at indices 0..=3.
    #[test]
    fn test_derive_hardened() {
        let sk_hex = "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb";
        let derived_hex = [
            "05eccb2d70e814f51a30d8b9965505605c677afa97228fa2419db583a8121db9",
            "612ae96bdce2e9bc01693ac579918fbb559e04ec365cce9b66bb80e328f62c46",
            "5df14a0a34fd6c30a80136d4103f0a93422ce82d5c537bebbecbc56e19fee5b9",
            "3ea55db88d9a6bf5f1d9c9de072e3c9a56b13f4156d72fca7880cd39b4bd4fdc",
        ];
        let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();

        // The position in `derived_hex` is the derivation index.
        for (i, hex) in derived_hex.iter().enumerate() {
            let derived = sk.derive_hardened(i as u32);
            assert_eq!(derived.to_bytes(), <[u8; 32]>::from_hex(hex).unwrap());
        }
    }

    // `Debug` output is `<PrivateKey {hex}>`.
    #[test]
    fn test_debug() {
        let sk_hex = "52d75c4707e39595b27314547f9723e5530c01198af3fc5849d9a7af65631efb";
        let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
        assert_eq!(format!("{sk:?}"), format!("<PrivateKey {sk_hex}>"));
    }

    // Keys from the same seed hash equally; a key from a different seed
    // hashes differently.
    #[test]
    fn test_hash() {
        fn hash<T: Hash>(v: &T) -> u64 {
            use std::collections::hash_map::DefaultHasher;
            let mut h = DefaultHasher::new();
            v.hash(&mut h);
            h.finish()
        }

        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        rng.fill(data.as_mut_slice());

        let sk1 = SecretKey::from_seed(&data);
        let sk2 = SecretKey::from_seed(&data);

        rng.fill(data.as_mut_slice());
        let sk3 = SecretKey::from_seed(&data);

        assert!(hash(&sk1) == hash(&sk2));
        assert!(hash(&sk1) != hash(&sk3));
    }

    // Any 32-byte value with the top bit set must be rejected with
    // `SecretKeyGroupOrder` (it is expected to be out of range).
    #[test]
    fn test_from_bytes() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            rng.fill(data.as_mut_slice());
            data[0] |= 0x80;
            assert_eq!(
                SecretKey::from_bytes(&data).unwrap_err(),
                Error::SecretKeyGroupOrder
            );
        }
    }

    // The all-zero encoding is accepted as a valid key.
    #[test]
    fn test_from_bytes_zero() {
        let data = [0u8; 32];
        let _sk = SecretKey::from_bytes(&data).unwrap();
    }

    // `sk + sk` and `sk + sk + sk` match precomputed sums.
    #[test]
    fn test_aggregate_secret_key() {
        let sk_hex = "5aac8405befe4cb3748a67177c56df26355f1f98d979afdb0b2f97858d2f71c3";
        let sk = SecretKey::from_bytes(&<[u8; 32]>::from_hex(sk_hex).unwrap()).unwrap();
        let sk2 = &sk + &sk;
        let sk3 = &sk + &sk + &sk;

        assert_eq!(
            sk2,
            SecretKey::from_bytes(
                &<[u8; 32]>::from_hex(
                    "416b60b8545f1c1eb5daf626ef0be64717009b2eb2f503b7165f2f0c1a5ee385"
                )
                .unwrap()
            )
            .unwrap()
        );

        assert_eq!(
            sk3,
            SecretKey::from_bytes(
                &<[u8; 32]>::from_hex(
                    "282a3d6ae9bfeb89f72b853661c0ed67f8a216c48c705793218ec692a78e5547"
                )
                .unwrap()
            )
            .unwrap()
        );
    }

    // `to_bytes` followed by `from_bytes` reproduces an equal key with an
    // equal public key.
    #[test]
    fn test_roundtrip() {
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            rng.fill(data.as_mut_slice());
            let sk = SecretKey::from_seed(&data);
            let bytes = sk.to_bytes();
            let sk2 = SecretKey::from_bytes(&bytes).unwrap();
            assert_eq!(sk, sk2);
            assert_eq!(sk.public_key(), sk2.public_key());
        }
    }
}
557
#[cfg(test)]
#[cfg(feature = "py-bindings")]
mod pytests {
    use super::*;
    use pyo3::Python;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use rstest::rstest;

    // `to_json_dict` followed by the Python-level `from_json_dict` reproduces
    // an equal key with an equal public key.
    #[test]
    fn test_json_dict_roundtrip() {
        pyo3::prepare_freethreaded_python();
        let mut rng = StdRng::seed_from_u64(1337);
        let mut data = [0u8; 32];
        for _i in 0..50 {
            rng.fill(data.as_mut_slice());
            let sk = SecretKey::from_seed(&data);
            Python::with_gil(|py| {
                let string = sk.to_json_dict(py).expect("to_json_dict");
                let py_class = py.get_type::<SecretKey>();
                let sk2 = SecretKey::from_json_dict(&py_class, py, string.bind(py))
                    .unwrap()
                    .extract(py)
                    .unwrap();
                assert_eq!(sk, sk2);
                assert_eq!(sk.public_key(), sk2.public_key());
            });
        }
    }

    // Malformed JSON hex strings (wrong length, with or without "0x", or
    // non-hex characters) are rejected with the given error message.
    #[rstest]
    #[case(
        "0x000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e",
        "PrivateKey, invalid length 31 expected 32"
    )]
    #[case(
        "0x000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f00",
        "PrivateKey, invalid length 33 expected 32"
    )]
    #[case(
        "000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f00",
        "PrivateKey, invalid length 33 expected 32"
    )]
    #[case(
        "000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e",
        "PrivateKey, invalid length 31 expected 32"
    )]
    #[case(
        "0r0102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f",
        "invalid hex"
    )]
    fn test_json_dict(#[case] input: &str, #[case] msg: &str) {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let py_class = py.get_type::<SecretKey>();
            let err = SecretKey::from_json_dict(
                &py_class,
                py,
                &input.to_string().into_pyobject(py).unwrap().into_any(),
            )
            .unwrap_err();
            assert_eq!(err.value(py).to_string(), msg.to_string());
        });
    }
}