1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use curve25519_entropic::{
    constants::ED25519_BASEPOINT_TABLE, edwards::CompressedEdwardsY, scalar::Scalar,
};

use sha2::{Digest, Sha512};

use crate::{errors::VRFError, utils::WEAK_KEYS};

/// A 32-byte Ed25519-style secret key seed.
///
/// `Debug` is implemented by hand so the secret bytes are never printed
/// (e.g. via `{:?}`, `dbg!`, or `assert_eq!` failure messages).
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct SecretKey {
    /// Raw secret seed bytes.
    pub(crate) bytes: [u8; 32],
}

impl std::fmt::Debug for SecretKey {
    /// Formats as `SecretKey { .. }`, deliberately omitting the secret bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SecretKey").finish_non_exhaustive()
    }
}

/// A public key stored as a compressed Edwards-Y point encoding.
///
/// The encoding is stored as given; it is only checked when `PublicKey::validate` is called.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PublicKey {
    /// Compressed Edwards-Y encoding of the public point.
    pub(crate) point: CompressedEdwardsY,
}

impl SecretKey {
    /// Creates a [`SecretKey`] by copying a 32-byte seed.
    #[must_use]
    pub fn new(bytes: &[u8; 32]) -> Self {
        SecretKey { bytes: *bytes }
    }

    /// Creates a [`SecretKey`] by copying a 32-byte seed out of a slice.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() != 32`.
    #[must_use]
    pub fn from_slice(bytes: &[u8]) -> Self {
        // Fail with an explicit message instead of the generic
        // `copy_from_slice` length-mismatch panic.
        assert_eq!(
            bytes.len(),
            32,
            "SecretKey::from_slice expects exactly 32 bytes"
        );
        let mut b = [0u8; 32];
        b.copy_from_slice(bytes);
        SecretKey { bytes: b }
    }

    /// Borrows the raw 32-byte seed.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }

    /// Returns a copy of the raw 32-byte seed.
    #[must_use]
    pub fn to_bytes(&self) -> [u8; 32] {
        self.bytes
    }

    /// Derives the public key and the secret scalar from this [`SecretKey`].
    ///
    /// The seed is hashed with SHA-512. The first 32 bytes of the digest are then clamped:
    /// the lowest three bits of byte 0 and the top bit of byte 31 are cleared, and the
    /// second-highest bit of byte 31 is set. The result is interpreted as a scalar `s`.
    /// The public key is the compressed encoding of `s·B`, where `B` is the Ed25519
    /// basepoint. This follows the Ed25519 key derivation of RFC 8032 §5.1.5.
    ///
    /// Returns `(public_key, scalar)`.
    ///
    /// # Errors
    ///
    /// Returns [`VRFError::InvalidSecretKey`] if the derived point, multiplied by the
    /// cofactor, compresses to `CompressedEdwardsY::default()`. That value is the
    /// identity encoding in upstream curve25519-dalek.
    pub fn extract_public_key_and_scalar(&self) -> Result<(PublicKey, Scalar), VRFError> {
        let mut hasher = Sha512::new();
        hasher.update(&self.bytes);
        let hash: [u8; 64] = hasher.finalize().into();

        // Only the first half of the SHA-512 output feeds the secret scalar.
        let mut digest: [u8; 32] = [0u8; 32];
        digest.copy_from_slice(&hash[..32]);
        // Clamp: clear the low 3 bits (makes the scalar a multiple of 8),
        // clear bit 255 and set bit 254.
        digest[0] &= 0xF8;
        digest[31] &= 0x7F;
        digest[31] |= 0x40;

        let scalar = Scalar::from_bits(digest);

        let point = &scalar * &ED25519_BASEPOINT_TABLE;
        // Reject the key if clearing the cofactor collapses the point to the identity.
        if point.mul_by_cofactor().compress() == CompressedEdwardsY::default() {
            return Err(VRFError::InvalidSecretKey {});
        }
        let pk = PublicKey {
            point: point.compress(),
        };
        Ok((pk, scalar))
    }
}

impl PublicKey {
    /// Wraps an already-compressed Edwards point.
    ///
    /// No validation is performed; call [`PublicKey::validate`] before trusting the key.
    #[must_use]
    pub fn new(point: CompressedEdwardsY) -> Self {
        PublicKey { point }
    }

    /// Builds a [`PublicKey`] from its 32-byte compressed encoding.
    ///
    /// The bytes are not checked to encode a valid curve point; use
    /// [`PublicKey::validate`] for that.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() != 32`.
    #[must_use]
    pub fn from_bytes(bytes: &[u8]) -> Self {
        // Fail with an explicit message instead of the generic
        // `copy_from_slice` length-mismatch panic.
        assert_eq!(
            bytes.len(),
            32,
            "PublicKey::from_bytes expects exactly 32 bytes"
        );
        let mut b = [0u8; 32];
        b.copy_from_slice(bytes);
        PublicKey {
            point: CompressedEdwardsY::from_slice(&b),
        }
    }

    /// Borrows the 32-byte compressed encoding.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        self.point.as_bytes()
    }

    /// Returns a copy of the 32-byte compressed encoding.
    #[must_use]
    pub fn to_bytes(&self) -> [u8; 32] {
        self.point.to_bytes()
    }

    /// Borrows the underlying compressed point.
    #[must_use]
    pub fn as_point(&self) -> &CompressedEdwardsY {
        &self.point
    }

    /// Validates this [`PublicKey`].
    ///
    /// # Errors
    ///
    /// Returns [`VRFError::InvalidPublicKey`] in either of these cases:
    /// - the encoding does not decompress to a curve point;
    /// - the encoding is listed in `WEAK_KEYS`. Presumably these are the known
    ///   small-order encodings (those whose cofactor multiple is the identity);
    ///   verify against `utils::WEAK_KEYS`.
    pub fn validate(&self) -> Result<(), VRFError> {
        // Reject byte strings that are not a valid point encoding at all.
        if self.point.decompress().is_none() {
            return Err(VRFError::InvalidPublicKey {});
        }
        // Reject valid points that are on the weak-key deny list.
        if WEAK_KEYS.contains(self.point.as_bytes()) {
            return Err(VRFError::InvalidPublicKey {});
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex test-vector literal, panicking on malformed input.
    fn unhex(s: &str) -> Vec<u8> {
        hex::decode(s).unwrap()
    }

    /// Checks public-key and scalar derivation against a fixed vector.
    /// The seed/public-key pair appears to be RFC 8032 §7.1 "TEST 1".
    #[test]
    fn test_extract_pk_scalar() {
        let seed = unhex("9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60");
        let expected_pk =
            unhex("d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a");
        let expected_scalar =
            unhex("307c83864f2833cb427a2ef1c00a013cfdff2768d980c0a3a520f006904de94f");

        let (pk, scalar) = SecretKey::from_slice(&seed)
            .extract_public_key_and_scalar()
            .unwrap();

        assert_eq!(pk.as_bytes(), expected_pk.as_slice());
        assert_eq!(scalar.as_bytes(), expected_scalar.as_slice());
    }
}