1use curve25519_entropic::{
2 constants::ED25519_BASEPOINT_TABLE, edwards::CompressedEdwardsY, scalar::Scalar,
3};
4
5use sha2::{Digest, Sha512};
6
7use crate::{errors::VRFError, utils::WEAK_KEYS};
8
/// A 32-byte VRF secret key.
///
/// The bytes are the seed that `extract_public_key_and_scalar` hashes with
/// SHA-512 to derive the clamped secret scalar and the public key.
///
/// `Debug` is implemented by hand rather than derived so that the raw key
/// material is never written into logs, panic messages or `{:?}` output.
///
/// NOTE(review): the derived `PartialEq` compares the bytes with ordinary
/// slice equality, which is not guaranteed to run in constant time; avoid
/// `==` on secret keys where timing may be observable.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct SecretKey {
    pub(crate) bytes: [u8; 32],
}

impl core::fmt::Debug for SecretKey {
    /// Formats the key without revealing its bytes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("SecretKey { bytes: <redacted> }")
    }
}
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub struct PublicKey {
16 pub(crate) point: CompressedEdwardsY,
17}
18
19impl SecretKey {
20 #[must_use]
21 pub fn new(bytes: &[u8; 32]) -> Self {
22 SecretKey { bytes: *bytes }
23 }
24
25 #[must_use]
26 pub fn from_slice(bytes: &[u8]) -> Self {
27 let mut b = [0u8; 32];
28 b.copy_from_slice(bytes);
29 SecretKey { bytes: b }
30 }
31
32 #[must_use]
33 pub fn as_bytes(&self) -> &[u8] {
34 &self.bytes
35 }
36
37 #[must_use]
38 pub fn to_bytes(&self) -> [u8; 32] {
39 self.bytes
40 }
41
42 pub fn extract_public_key_and_scalar(&self) -> Result<(PublicKey, Scalar), VRFError> {
48 let mut hasher = Sha512::new();
49 hasher.update(&self.bytes);
50 let hash: [u8; 64] = hasher.finalize().into();
51 let mut digest: [u8; 32] = [0u8; 32];
52 digest.copy_from_slice(&hash[..32]);
53 digest[0] &= 0xF8;
54 digest[31] &= 0x7F;
55 digest[31] |= 0x40;
56
57 let scalar = Scalar::from_bits(digest);
58
59 let point = &scalar * &ED25519_BASEPOINT_TABLE;
60 if point.mul_by_cofactor().compress() == CompressedEdwardsY::default() {
61 return Err(VRFError::InvalidSecretKey {});
62 }
63 let pk = PublicKey {
64 point: point.compress(),
65 };
66 Ok((pk, scalar))
67 }
68}
69
70impl PublicKey {
71 #[must_use]
72 pub fn new(point: CompressedEdwardsY) -> Self {
73 PublicKey { point }
74 }
75
76 #[must_use]
77 pub fn from_bytes(bytes: &[u8]) -> Self {
78 let mut b = [0u8; 32];
79 b.copy_from_slice(bytes);
80 PublicKey {
81 point: CompressedEdwardsY::from_slice(&b),
82 }
83 }
84
85 #[must_use]
86 pub fn as_bytes(&self) -> &[u8] {
87 self.point.as_bytes()
88 }
89
90 #[must_use]
91 pub fn to_bytes(&self) -> [u8; 32] {
92 self.point.to_bytes()
93 }
94
95 #[must_use]
96 pub fn as_point(&self) -> &CompressedEdwardsY {
97 &self.point
98 }
99
100 pub fn validate(&self) -> Result<(), VRFError> {
106 if self.point.decompress().is_some() {
107 if WEAK_KEYS.contains(self.point.as_bytes()) {
108 return Err(VRFError::InvalidPublicKey {});
109 }
110 } else {
111 return Err(VRFError::InvalidPublicKey {});
112 }
113 Ok(())
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex test-vector string, panicking on malformed input.
    fn unhex(s: &str) -> Vec<u8> {
        hex::decode(s).unwrap()
    }

    /// Key derivation must reproduce a fixed known-answer vector.
    /// (The key pair appears to match RFC 8032 §7.1 TEST 1.)
    #[test]
    fn test_extract_pk_scalar() {
        let sk_bytes = unhex("9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60");
        let expected_pk = unhex("d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a");
        let expected_scalar =
            unhex("307c83864f2833cb427a2ef1c00a013cfdff2768d980c0a3a520f006904de94f");

        let (pk, scalar) = SecretKey::from_slice(&sk_bytes)
            .extract_public_key_and_scalar()
            .unwrap();

        assert_eq!(pk.as_bytes(), expected_pk.as_slice());
        assert_eq!(scalar.as_bytes(), expected_scalar.as_slice());
    }
}