1use curve25519_dalek::{
31 constants::RISTRETTO_BASEPOINT_TABLE,
32 ristretto::{CompressedRistretto, RistrettoPoint},
33 scalar::Scalar,
34};
35use rand::Rng;
36use serde::{Deserialize, Serialize};
37use thiserror::Error;
38use zeroize::Zeroize;
39
40#[derive(Error, Debug)]
42pub enum ElGamalError {
43 #[error("Invalid ciphertext")]
44 InvalidCiphertext,
45 #[error("Invalid public key")]
46 InvalidPublicKey,
47 #[error("Decryption failed")]
48 DecryptionFailed,
49 #[error("Value out of range (max 2^32)")]
50 ValueOutOfRange,
51 #[error("Serialization error: {0}")]
52 SerializationError(String),
53}
54
55pub type ElGamalResult<T> = Result<T, ElGamalError>;
56
57#[derive(Clone, Zeroize)]
59#[zeroize(drop)]
60pub struct ElGamalSecretKey {
61 scalar: Scalar,
62}
63
64impl ElGamalSecretKey {
65 pub fn generate() -> Self {
67 let mut rng = rand::thread_rng();
68 let mut bytes = [0u8; 32];
69 rng.fill(&mut bytes);
70 let scalar = Scalar::from_bytes_mod_order(bytes);
71 Self { scalar }
72 }
73
74 pub fn from_bytes(bytes: &[u8; 32]) -> ElGamalResult<Self> {
76 let scalar = Scalar::from_bytes_mod_order(*bytes);
77 Ok(Self { scalar })
78 }
79
80 pub fn to_bytes(&self) -> [u8; 32] {
82 self.scalar.to_bytes()
83 }
84
85 pub fn public_key(&self) -> ElGamalPublicKey {
87 let point = RISTRETTO_BASEPOINT_TABLE * &self.scalar;
88 ElGamalPublicKey { point }
89 }
90}
91
92#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
94pub struct ElGamalPublicKey {
95 point: RistrettoPoint,
96}
97
98impl ElGamalPublicKey {
99 pub fn from_bytes(bytes: &[u8; 32]) -> ElGamalResult<Self> {
101 let compressed =
102 CompressedRistretto::from_slice(bytes).map_err(|_| ElGamalError::InvalidPublicKey)?;
103 let point = compressed
104 .decompress()
105 .ok_or(ElGamalError::InvalidPublicKey)?;
106 Ok(Self { point })
107 }
108
109 pub fn to_bytes(&self) -> [u8; 32] {
111 self.point.compress().to_bytes()
112 }
113}
114
115#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
121pub struct ElGamalCiphertext {
122 c1: RistrettoPoint,
123 c2: RistrettoPoint,
124}
125
126impl ElGamalCiphertext {
127 pub fn from_bytes(bytes: &[u8; 64]) -> ElGamalResult<Self> {
129 let mut c1_bytes = [0u8; 32];
130 let mut c2_bytes = [0u8; 32];
131 c1_bytes.copy_from_slice(&bytes[..32]);
132 c2_bytes.copy_from_slice(&bytes[32..]);
133
134 let compressed_c1 = CompressedRistretto::from_slice(&c1_bytes)
135 .map_err(|_| ElGamalError::InvalidCiphertext)?;
136 let compressed_c2 = CompressedRistretto::from_slice(&c2_bytes)
137 .map_err(|_| ElGamalError::InvalidCiphertext)?;
138
139 let c1 = compressed_c1
140 .decompress()
141 .ok_or(ElGamalError::InvalidCiphertext)?;
142 let c2 = compressed_c2
143 .decompress()
144 .ok_or(ElGamalError::InvalidCiphertext)?;
145
146 Ok(Self { c1, c2 })
147 }
148
149 pub fn to_bytes(&self) -> [u8; 64] {
151 let mut bytes = [0u8; 64];
152 bytes[..32].copy_from_slice(&self.c1.compress().to_bytes());
153 bytes[32..].copy_from_slice(&self.c2.compress().to_bytes());
154 bytes
155 }
156
157 pub fn add(&self, other: &ElGamalCiphertext) -> ElGamalCiphertext {
159 ElGamalCiphertext {
160 c1: self.c1 + other.c1,
161 c2: self.c2 + other.c2,
162 }
163 }
164
165 pub fn mul_scalar(&self, scalar: u64) -> ElGamalCiphertext {
167 let s = Scalar::from(scalar);
168 ElGamalCiphertext {
169 c1: self.c1 * s,
170 c2: self.c2 * s,
171 }
172 }
173
174 pub fn rerandomize(&self, public_key: &ElGamalPublicKey) -> ElGamalCiphertext {
177 let mut rng = rand::thread_rng();
178 let mut r_bytes = [0u8; 32];
179 rng.fill(&mut r_bytes);
180 let r = Scalar::from_bytes_mod_order(r_bytes);
181
182 let delta_c1 = RISTRETTO_BASEPOINT_TABLE * &r;
184 let delta_c2 = public_key.point * r;
185
186 ElGamalCiphertext {
187 c1: self.c1 + delta_c1,
188 c2: self.c2 + delta_c2,
189 }
190 }
191}
192
193pub struct ElGamalKeypair {
195 secret_key: ElGamalSecretKey,
196 public_key: ElGamalPublicKey,
197}
198
199impl ElGamalKeypair {
200 pub fn generate() -> Self {
202 let secret_key = ElGamalSecretKey::generate();
203 let public_key = secret_key.public_key();
204 Self {
205 secret_key,
206 public_key,
207 }
208 }
209
210 pub fn from_secret_key(secret_key: ElGamalSecretKey) -> Self {
212 let public_key = secret_key.public_key();
213 Self {
214 secret_key,
215 public_key,
216 }
217 }
218
219 pub fn public_key(&self) -> ElGamalPublicKey {
221 self.public_key
222 }
223
224 pub fn secret_key(&self) -> &ElGamalSecretKey {
226 &self.secret_key
227 }
228
229 pub fn encrypt(&self, message: u64) -> ElGamalCiphertext {
234 encrypt(&self.public_key, message)
235 }
236
237 pub fn decrypt(&self, ciphertext: &ElGamalCiphertext) -> ElGamalResult<u64> {
242 decrypt(&self.secret_key, ciphertext)
243 }
244}
245
246pub fn encrypt(public_key: &ElGamalPublicKey, message: u64) -> ElGamalCiphertext {
248 let mut rng = rand::thread_rng();
250 let mut r_bytes = [0u8; 32];
251 rng.fill(&mut r_bytes);
252 let r = Scalar::from_bytes_mod_order(r_bytes);
253
254 let m_scalar = Scalar::from(message);
256 let m_point = RISTRETTO_BASEPOINT_TABLE * &m_scalar;
257
258 let c1 = RISTRETTO_BASEPOINT_TABLE * &r;
260
261 let c2 = m_point + (public_key.point * r);
263
264 ElGamalCiphertext { c1, c2 }
265}
266
267pub fn decrypt(
272 secret_key: &ElGamalSecretKey,
273 ciphertext: &ElGamalCiphertext,
274) -> ElGamalResult<u64> {
275 let m_point = ciphertext.c2 - (secret_key.scalar * ciphertext.c1);
277
278 solve_discrete_log(&m_point)
281}
282
283fn solve_discrete_log(point: &RistrettoPoint) -> ElGamalResult<u64> {
287 const MAX_SEARCH: u64 = 1 << 20; const BATCH_SIZE: u64 = 1 << 10; let mut baby_steps = std::collections::HashMap::new();
293 let mut current = RistrettoPoint::default(); let generator = RISTRETTO_BASEPOINT_TABLE * &Scalar::ONE;
295
296 for i in 0..BATCH_SIZE {
297 baby_steps.insert(current.compress().to_bytes(), i);
298 current += generator;
299 }
300
301 let giant_step = generator * Scalar::from(BATCH_SIZE);
303 let mut current = *point;
304
305 for j in 0..(MAX_SEARCH / BATCH_SIZE) {
306 if let Some(&i) = baby_steps.get(¤t.compress().to_bytes()) {
307 return Ok(j * BATCH_SIZE + i);
308 }
309 current -= giant_step;
310 }
311
312 Err(ElGamalError::DecryptionFailed)
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A public key survives a to-bytes / from-bytes round trip.
    #[test]
    fn test_keypair_generation() {
        let original = ElGamalKeypair::generate().public_key();

        let restored = ElGamalPublicKey::from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original, restored);
    }

    /// Decrypting an encryption returns the original message.
    #[test]
    fn test_encrypt_decrypt() {
        let keypair = ElGamalKeypair::generate();
        let plaintext = 42u64;

        let recovered = keypair.decrypt(&keypair.encrypt(plaintext)).unwrap();

        assert_eq!(plaintext, recovered);
    }

    /// Adding two ciphertexts yields a ciphertext of the plaintext sum.
    #[test]
    fn test_homomorphic_addition() {
        let keypair = ElGamalKeypair::generate();
        let (lhs, rhs) = (100u64, 200u64);

        let combined = keypair.encrypt(lhs).add(&keypair.encrypt(rhs));

        let total = keypair.decrypt(&combined).unwrap();
        assert_eq!(total, lhs + rhs);
    }

    /// Scaling a ciphertext scales the plaintext by the same factor.
    #[test]
    fn test_scalar_multiplication() {
        let keypair = ElGamalKeypair::generate();
        let value = 50u64;
        let factor = 3u64;

        let scaled = keypair.encrypt(value).mul_scalar(factor);

        let product = keypair.decrypt(&scaled).unwrap();
        assert_eq!(product, value * factor);
    }

    /// Rerandomizing changes the ciphertext but not the plaintext.
    #[test]
    fn test_rerandomization() {
        let keypair = ElGamalKeypair::generate();
        let plaintext = 123u64;

        let before = keypair.encrypt(plaintext);
        let after = before.rerandomize(&keypair.public_key());

        assert_ne!(before, after);

        assert_eq!(keypair.decrypt(&before).unwrap(), plaintext);
        assert_eq!(keypair.decrypt(&after).unwrap(), plaintext);
    }

    /// A ciphertext survives a to-bytes / from-bytes round trip.
    #[test]
    fn test_ciphertext_serialization() {
        let keypair = ElGamalKeypair::generate();
        let plaintext = 777u64;

        let original = keypair.encrypt(plaintext);
        let restored = ElGamalCiphertext::from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(original, restored);
        assert_eq!(keypair.decrypt(&restored).unwrap(), plaintext);
    }

    /// Zero encrypts and decrypts like any other message.
    #[test]
    fn test_zero_message() {
        let keypair = ElGamalKeypair::generate();
        let plaintext = 0u64;

        let recovered = keypair.decrypt(&keypair.encrypt(plaintext)).unwrap();

        assert_eq!(plaintext, recovered);
    }

    /// A message that needs several giant steps still decrypts.
    #[test]
    fn test_large_message() {
        let keypair = ElGamalKeypair::generate();
        let plaintext = 10000u64;

        let recovered = keypair.decrypt(&keypair.encrypt(plaintext)).unwrap();

        assert_eq!(plaintext, recovered);
    }

    /// Folding several ciphertexts into an encryption of zero sums them.
    #[test]
    fn test_multiple_additions() {
        let keypair = ElGamalKeypair::generate();
        let values = [10u64, 20, 30, 40, 50];
        let expected_sum: u64 = values.iter().sum();

        let accumulated = values
            .iter()
            .fold(keypair.encrypt(0), |acc, &value| {
                acc.add(&keypair.encrypt(value))
            });

        let total = keypair.decrypt(&accumulated).unwrap();
        assert_eq!(total, expected_sum);
    }

    /// A secret key survives a to-bytes / from-bytes round trip.
    #[test]
    fn test_secret_key_serialization() {
        let original = ElGamalSecretKey::generate();
        let restored = ElGamalSecretKey::from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(original.to_bytes(), restored.to_bytes());
        assert_eq!(
            original.public_key().to_bytes(),
            restored.public_key().to_bytes()
        );
    }

    /// The same secret bytes always derive the same public key.
    #[test]
    fn test_deterministic_public_key() {
        let seed = [42u8; 32];
        let first = ElGamalSecretKey::from_bytes(&seed).unwrap();
        let second = ElGamalSecretKey::from_bytes(&seed).unwrap();

        assert_eq!(first.public_key(), second.public_key());
    }

    /// Encrypting the same message twice gives different ciphertexts.
    #[test]
    fn test_encryption_randomness() {
        let keypair = ElGamalKeypair::generate();
        let plaintext = 100u64;

        let first = keypair.encrypt(plaintext);
        let second = keypair.encrypt(plaintext);

        assert_ne!(first, second);

        assert_eq!(keypair.decrypt(&first).unwrap(), plaintext);
        assert_eq!(keypair.decrypt(&second).unwrap(), plaintext);
    }
}