1use merlin::Transcript;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use core::iter;
8
9use super::{lagrange_coefficients, Error, Params, PublicPolynomial};
10
11use crate::{
12 alloc::Vec,
13 group::Group,
14 proofs::{LogEqualityProof, ProofOfPossession, TranscriptForGroup, VerificationError},
15 CandidateDecryption, Ciphertext, PublicKey, VerifiableDecryption,
16};
17
18#[derive(Debug, Clone)]
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22#[cfg_attr(feature = "serde", serde(bound = ""))]
23pub struct PublicKeySet<G: Group> {
24 params: Params,
25 shared_key: PublicKey<G>,
26 participant_keys: Vec<PublicKey<G>>,
27}
28
29impl<G: Group> PublicKeySet<G> {
30 pub(crate) fn validate(
31 params: Params,
32 public_polynomial: &[G::Element],
33 proof_of_possession: &ProofOfPossession<G>,
34 ) -> Result<(), Error> {
35 if public_polynomial.len() != params.threshold {
36 return Err(Error::MalformedDealerPolynomial);
37 }
38
39 let mut transcript = Transcript::new(b"elgamal_share_poly");
40 transcript.append_u64(b"n", params.shares as u64);
41 transcript.append_u64(b"t", params.threshold as u64);
42
43 let public_poly_keys: Vec<_> = public_polynomial
44 .iter()
45 .copied()
46 .map(PublicKey::from_element)
47 .collect();
48 proof_of_possession
49 .verify(public_poly_keys.iter(), &mut transcript)
50 .map_err(Error::InvalidDealerProof)?;
51 Ok(())
52 }
53
54 pub fn new(
62 params: Params,
63 public_polynomial: Vec<G::Element>,
64 proof_of_possession: &ProofOfPossession<G>,
65 ) -> Result<Self, Error> {
66 Self::validate(params, &public_polynomial, proof_of_possession)?;
67
68 let public_poly = PublicPolynomial::<G>(public_polynomial);
69 let shared_key = PublicKey::from_element(public_poly.value_at_zero());
70 let participant_keys = (0..params.shares)
71 .map(|idx| PublicKey::from_element(public_poly.value_at((idx as u64 + 1).into())))
72 .collect();
73
74 Ok(Self {
75 params,
76 shared_key,
77 participant_keys,
78 })
79 }
80
81 pub fn from_participants(
89 params: Params,
90 participant_keys: Vec<PublicKey<G>>,
91 ) -> Result<Self, Error> {
92 if params.shares != participant_keys.len() {
93 return Err(Error::ParticipantCountMismatch);
94 }
95
96 let indexes: Vec<_> = (0..params.threshold).collect();
98 let (denominators, scale) = lagrange_coefficients::<G>(&indexes);
99 let starting_keys = participant_keys
100 .iter()
101 .map(PublicKey::as_element)
102 .take(params.threshold);
103 let shared_key = G::vartime_multi_mul(&denominators, starting_keys.clone());
104 let shared_key = PublicKey::from_element(shared_key * &scale);
105
106 let mut inverses: Vec<_> = (1_u64..=params.shares as u64)
110 .map(G::Scalar::from)
111 .collect();
112 G::invert_scalars(&mut inverses);
113
114 for (x, key) in participant_keys.iter().enumerate().skip(params.threshold) {
115 let mut key_scale = indexes
116 .iter()
117 .map(|&idx| G::Scalar::from((x - idx) as u64))
118 .fold(G::Scalar::from(1), |acc, value| acc * value);
119
120 let key_denominators: Vec<_> = denominators
121 .iter()
122 .enumerate()
123 .map(|(idx, &d)| d * G::Scalar::from(idx as u64 + 1) * inverses[x - idx - 1])
124 .collect();
125
126 if params.threshold % 2 == 0 {
130 key_scale = -key_scale;
131 }
132
133 let interpolated_key = G::vartime_multi_mul(&key_denominators, starting_keys.clone());
134 let interpolated_key = interpolated_key * &key_scale;
135 if interpolated_key != key.as_element() {
136 return Err(Error::MalformedParticipantKeys);
137 }
138 }
139
140 Ok(Self {
141 params,
142 shared_key,
143 participant_keys,
144 })
145 }
146
147 pub fn params(&self) -> Params {
149 self.params
150 }
151
152 pub fn shared_key(&self) -> &PublicKey<G> {
154 &self.shared_key
155 }
156
157 pub fn participant_key(&self, index: usize) -> Option<&PublicKey<G>> {
160 self.participant_keys.get(index)
161 }
162
163 pub fn participant_keys(&self) -> &[PublicKey<G>] {
165 &self.participant_keys
166 }
167
168 pub(super) fn commit(&self, transcript: &mut Transcript) {
169 transcript.append_u64(b"n", self.params.shares as u64);
170 transcript.append_u64(b"t", self.params.threshold as u64);
171 transcript.append_element_bytes(b"K", self.shared_key.as_bytes());
172 }
173
174 pub fn verify_participant(
188 &self,
189 index: usize,
190 proof: &ProofOfPossession<G>,
191 ) -> Result<(), VerificationError> {
192 let participant_key = self.participant_key(index).unwrap_or_else(|| {
193 panic!(
194 "participant index {index} out of bounds, expected a value in 0..{}",
195 self.participant_keys.len()
196 );
197 });
198 let mut transcript = Transcript::new(b"elgamal_participant_pop");
199 self.commit(&mut transcript);
200 transcript.append_u64(b"i", index as u64);
201 proof.verify(iter::once(participant_key), &mut transcript)
202 }
203
204 pub fn verify_share(
211 &self,
212 candidate_share: CandidateDecryption<G>,
213 ciphertext: Ciphertext<G>,
214 index: usize,
215 proof: &LogEqualityProof<G>,
216 ) -> Result<VerifiableDecryption<G>, VerificationError> {
217 let key_share = self.participant_keys[index].as_element();
218 let dh_element = candidate_share.dh_element();
219 let mut transcript = Transcript::new(b"elgamal_decryption_share");
220 self.commit(&mut transcript);
221 transcript.append_u64(b"i", index as u64);
222
223 proof.verify(
224 &PublicKey::from_element(ciphertext.random_element),
225 (key_share, dh_element),
226 &mut transcript,
227 )?;
228 Ok(VerifiableDecryption::from_element(dh_element))
229 }
230}
231
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;
    use crate::{
        group::{ElementOps, Ristretto},
        sharing::Dealer,
    };

    /// Restores a key set from honestly generated participant keys, checks the restored values,
    /// then checks that malformed inputs are rejected with the expected errors.
    #[test]
    fn restoring_key_set_from_participant_keys_errors() {
        let mut rng = thread_rng();
        let params = Params::new(10, 7);

        let dealer = Dealer::<Ristretto>::new(params, &mut rng);
        let (public_poly, _) = dealer.public_info();
        let public_poly = PublicPolynomial::<Ristretto>(public_poly);
        let participant_keys: Vec<PublicKey<Ristretto>> = (1..=params.shares)
            .map(|i| PublicKey::from_element(public_poly.value_at((i as u64).into())))
            .collect();

        // Happy path: the restored set must reproduce the dealer's shared key and keep the
        // participant keys intact, not merely return `Ok`.
        let key_set = PublicKeySet::from_participants(params, participant_keys.clone()).unwrap();
        assert!(key_set.shared_key().as_element() == public_poly.value_at_zero());
        let restored = key_set
            .participant_keys()
            .iter()
            .map(PublicKey::as_element);
        let expected = participant_keys.iter().map(PublicKey::as_element);
        assert!(restored.eq(expected));

        // Too few keys for the declared number of shares.
        let err =
            PublicKeySet::from_participants(params, participant_keys[1..].to_vec()).unwrap_err();
        assert!(matches!(err, Error::ParticipantCountMismatch));

        // Keys in the wrong order no longer lie on a single polynomial.
        let mut bogus_keys = participant_keys.clone();
        bogus_keys.swap(1, 5);
        let err = PublicKeySet::from_participants(params, bogus_keys).unwrap_err();
        assert!(matches!(err, Error::MalformedParticipantKeys));

        // Tampering with any single key, whether it is used for interpolation or checked
        // afterwards, must be detected.
        for i in 0..params.shares {
            let mut bogus_keys = participant_keys.clone();
            bogus_keys[i] =
                PublicKey::from_element(bogus_keys[i].as_element() + Ristretto::generator());
            let err = PublicKeySet::from_participants(params, bogus_keys).unwrap_err();
            assert!(matches!(err, Error::MalformedParticipantKeys));
        }
    }
}