1use crate::util::*;
2use bs58;
3use curve25519_dalek::constants::{
4 RISTRETTO_BASEPOINT_POINT as G, RISTRETTO_BASEPOINT_TABLE as GT,
5};
6use std::borrow::Borrow;
7use subtle::{ConditionallySelectable, ConstantTimeEq};
8
/// A VRF public key.
///
/// Stores the packed 32-byte encoding (`.0`) next to the decoded point (`.1`), so the bytes can
/// be hashed or exported without re-packing the point.
#[derive(Clone)]
pub struct PublicKey(pub(crate) [u8; 32], pub(crate) Point);
/// A VRF secret key: the secret scalar (`.0`) together with the public key derived from it (`.1`).
#[derive(Clone)]
pub struct SecretKey(Scalar, PublicKey);
// VRF output: 32 bytes holding the packed point produced by `SecretKey::compute`.
value_type!(pub, Value, 32, "value");
// VRF proof: 64 bytes holding the packed `(r, c)` scalar pair produced by
// `SecretKey::compute_with_proof`.
value_type!(pub, Proof, 64, "proof");
15
16impl PublicKey {
17 fn from_bytes(bytes: &[u8; 32]) -> Option<Self> {
18 Some(PublicKey(*bytes, unpack(bytes)?))
19 }
20
21 fn offset(&self, input: &[u8]) -> Scalar {
22 hash_s!(&self.0, input)
23 }
24
25 pub fn is_vrf_valid(&self, input: &impl Borrow<[u8]>, value: &Value, proof: &Proof) -> bool {
26 self.is_valid(input.borrow(), value, proof)
27 }
28
29 #[allow(clippy::arithmetic_side_effects)]
33 fn is_valid(&self, input: &[u8], value: &Value, proof: &Proof) -> bool {
34 let p = unwrap_or_return_false!(unpack(&value.0));
35 let (r, c) = unwrap_or_return_false!(unpack(&proof.0));
36 hash_s!(
37 &self.0,
38 &value.0,
39 vmul2(r + c * self.offset(input), &G, c, &self.1),
40 vmul2(r, &p, c, &G)
41 ) == c
42 }
43}
44
45#[allow(clippy::arithmetic_side_effects)]
50fn basemul(s: Scalar) -> Point {
51 &s * &*GT
52}
53
54fn safe_invert(s: Scalar) -> Scalar {
55 Scalar::conditional_select(&s, &Scalar::ONE, s.ct_eq(&Scalar::ZERO)).invert()
56}
57
58impl SecretKey {
59 pub(crate) fn from_scalar(sk: Scalar) -> Self {
60 let pk = basemul(sk);
61 SecretKey(sk, PublicKey(pk.pack(), pk))
62 }
63
64 fn from_bytes(bytes: &[u8; 32]) -> Option<Self> {
65 Some(Self::from_scalar(unpack(bytes)?))
66 }
67
68 pub fn public_key(&self) -> &PublicKey {
69 &self.1
70 }
71
72 pub fn compute_vrf(&self, input: &impl Borrow<[u8]>) -> Value {
73 self.compute(input.borrow())
74 }
75
76 #[allow(clippy::arithmetic_side_effects)]
80 fn compute(&self, input: &[u8]) -> Value {
81 Value(basemul(safe_invert(self.0 + self.1.offset(input))).pack())
82 }
83
84 pub fn compute_vrf_with_proof(&self, input: &impl Borrow<[u8]>) -> (Value, Proof) {
85 self.compute_with_proof(input.borrow())
86 }
87
88 #[allow(clippy::arithmetic_side_effects)]
92 fn compute_with_proof(&self, input: &[u8]) -> (Value, Proof) {
93 let x = self.0 + self.1.offset(input);
94 let inv = safe_invert(x);
95 let val = basemul(inv).pack();
96 let k = prs!(x);
97 let c = hash_s!(&(self.1).0, &val, basemul(k), basemul(inv * k));
98 (Value(val), Proof((k - c * x, c).pack()))
99 }
100
101 pub fn is_vrf_valid(&self, input: &impl Borrow<[u8]>, value: &Value, proof: &Proof) -> bool {
102 self.1.is_valid(input.borrow(), value, proof)
103 }
104}
105
/// Implements equality, the shared fixed-size byte conversions, and `TryFrom<&[u8; $l]>` for a
/// key type.
///
/// * `$ty` - the key type; it must provide `fn from_bytes(&[u8; $l]) -> Option<Self>`.
/// * `$l` - length in bytes of the type's encoding.
/// * `$bytes` - closure taking the key and returning its byte encoding, forwarded to
///   `common_conversions_fixed!`.
/// * `$what` - human-readable name forwarded to `common_conversions_fixed!`.
macro_rules! traits {
    ($ty:ident, $l:literal, $bytes:expr, $what:literal) => {
        // Equality compares the first tuple field of the key.
        eq!($ty, |a, b| a.0 == b.0);
        // Forward `$l` rather than a hard-coded `32` so these conversions always agree with the
        // `TryFrom<&[u8; $l]>` impl below.
        common_conversions_fixed!($ty, $l, $bytes, $what);

        impl TryFrom<&[u8; $l]> for $ty {
            type Error = ();
            fn try_from(value: &[u8; $l]) -> Result<Self, ()> {
                Self::from_bytes(value).ok_or(())
            }
        }
    };
}
119
// Both key types use a 32-byte encoding: the packed point bytes for `PublicKey` and the
// scalar's byte representation for `SecretKey`.
traits!(PublicKey, 32, |s| &s.0, "public key");
traits!(SecretKey, 32, |s| s.0.as_bytes(), "secret key");
122
#[cfg(feature = "rand")]
#[cfg(test)]
mod tests {
    use super::*;

    use secp256k1::rand::rngs::OsRng;
    use serde::{Deserialize, Serialize};
    use serde_json::{from_str, to_string};

    /// Generates a fresh secret key from OS randomness.
    fn random_secret_key() -> SecretKey {
        SecretKey::from_scalar(Scalar::random(&mut OsRng))
    }

    #[test]
    fn test_conversion() {
        let sk = random_secret_key();
        let sk2 = SecretKey::from_bytes(&sk.clone().into()).unwrap();
        assert_eq!(sk, sk2);
        let pk = sk.public_key();
        let pk2 = sk2.public_key();
        let pk3 = PublicKey::from_bytes(&pk2.into()).unwrap();
        assert_eq!(pk, pk2);
        assert_eq!(pk.clone(), pk3);
    }

    #[test]
    fn test_verify() {
        let sk = random_secret_key();
        let (val, proof) = sk.compute_vrf_with_proof(b"Test");
        let val2 = sk.compute_vrf(b"Test");
        assert_eq!(val, val2);
        assert!(sk.public_key().is_vrf_valid(b"Test", &val, &proof));
        assert!(!sk.public_key().is_vrf_valid(b"Tent", &val, &proof));
        // A value computed for a different input must not verify under this input's proof.
        let other = sk.compute_vrf(b"Tent");
        assert_ne!(val, other);
        assert!(!sk.public_key().is_vrf_valid(b"Test", &other, &proof));
    }

    #[test]
    fn test_different_keys() {
        let sk = random_secret_key();
        let sk2 = random_secret_key();
        assert_ne!(sk, sk2);
        assert_ne!(Into::<[u8; 32]>::into(sk.clone()), Into::<[u8; 32]>::into(sk2.clone()));
        let pk = sk.public_key();
        let pk2 = sk2.public_key();
        assert_ne!(pk, pk2);
        assert_ne!(Into::<[u8; 32]>::into(pk), Into::<[u8; 32]>::into(pk2));
        let (val, proof) = sk.compute_vrf_with_proof(b"Test");
        let (val2, proof2) = sk2.compute_vrf_with_proof(b"Test");
        assert_ne!(val, val2);
        assert_ne!(proof, proof2);
        // Each key must accept its own output...
        assert!(pk.is_vrf_valid(b"Test", &val, &proof));
        assert!(pk2.is_vrf_valid(b"Test", &val2, &proof2));
        // ...and reject anything involving the other key's value or proof.
        assert!(!pk2.is_vrf_valid(b"Test", &val, &proof));
        assert!(!pk2.is_vrf_valid(b"Test", &val2, &proof));
        assert!(!pk2.is_vrf_valid(b"Test", &val, &proof2));
    }

    /// Serializes `value` to JSON and parses it back.
    fn round_trip<T: Serialize + for<'de> Deserialize<'de>>(value: &T) -> T {
        from_str(to_string(value).unwrap().as_str()).unwrap()
    }

    #[test]
    fn test_serialize() {
        let sk = random_secret_key();
        let sk2 = round_trip(&sk);
        assert_eq!(sk, sk2);
        let (val, proof) = sk.compute_vrf_with_proof(b"Test");
        let (val2, proof2) = sk2.compute_vrf_with_proof(b"Test");
        let (val3, proof3) = (round_trip(&val), round_trip(&proof));
        assert_eq!((val, proof), (val2, proof2));
        assert_eq!((val, proof), (val3, proof3));
        let pk = sk.public_key();
        let pk2 = sk2.public_key();
        let pk3 = round_trip(pk);
        assert!(pk.is_vrf_valid(b"Test", &val, &proof));
        assert!(pk2.is_vrf_valid(b"Test", &val, &proof));
        assert!(pk3.is_vrf_valid(b"Test", &val, &proof));
    }
}