chie_crypto/
elgamal.rs

1//! ElGamal encryption for additively homomorphic public key encryption.
2//!
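//! ElGamal encryption provides:
//! - Additively homomorphic encryption: E(m₁) + E(m₂) = E(m₁ + m₂)
//! - Re-randomization support for ciphertext unlinkability
//! - Public key encryption with semantic security
//! - Useful for privacy-preserving aggregations in the CHIE protocol
//!
//! Messages are encoded in the exponent (as m·G), so decryption must solve a
//! discrete logarithm by bounded search: only plaintexts below 2^20, including
//! the results of homomorphic operations, can be recovered.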
8//!
9//! # Example
10//! ```
11//! use chie_crypto::elgamal::{ElGamalKeypair, ElGamalCiphertext};
12//!
13//! // Generate a keypair
14//! let keypair = ElGamalKeypair::generate();
15//!
16//! // Encrypt messages
17//! let msg1 = 100u64;
18//! let msg2 = 200u64;
19//! let ct1 = keypair.encrypt(msg1);
20//! let ct2 = keypair.encrypt(msg2);
21//!
22//! // Homomorphic addition
23//! let ct_sum = ct1.add(&ct2);
24//!
25//! // Decrypt the sum
26//! let sum = keypair.decrypt(&ct_sum).unwrap();
27//! assert_eq!(sum, msg1 + msg2);
28//! ```
29
30use curve25519_dalek::{
31    constants::RISTRETTO_BASEPOINT_TABLE,
32    ristretto::{CompressedRistretto, RistrettoPoint},
33    scalar::Scalar,
34};
35use rand::Rng;
36use serde::{Deserialize, Serialize};
37use thiserror::Error;
38use zeroize::Zeroize;
39
/// ElGamal encryption error types
#[derive(Error, Debug)]
pub enum ElGamalError {
    /// A ciphertext could not be parsed, e.g. one of its two compressed points
    /// failed to decompress to a valid Ristretto point.
    #[error("Invalid ciphertext")]
    InvalidCiphertext,
    /// A public key could not be parsed from its compressed encoding.
    #[error("Invalid public key")]
    InvalidPublicKey,
    /// The bounded discrete-log search performed during decryption found no
    /// plaintext in range.
    #[error("Decryption failed")]
    DecryptionFailed,
    /// A value exceeded the supported plaintext range.
    // NOTE(review): this variant is not constructed in this file, and the
    // "2^32" in its message does not match the 2^20 bound searched by the
    // decryption routine here — confirm the intended limit with its callers.
    #[error("Value out of range (max 2^32)")]
    ValueOutOfRange,
    /// Serialization/deserialization failure carrying a description string.
    // NOTE(review): not constructed in this file; presumably used by callers
    // wrapping serde errors — verify.
    #[error("Serialization error: {0}")]
    SerializationError(String),
}

/// Result alias used by the fallible ElGamal operations in this module.
pub type ElGamalResult<T> = Result<T, ElGamalError>;
56
57/// ElGamal secret key (scalar in the Ristretto group)
58#[derive(Clone, Zeroize)]
59#[zeroize(drop)]
60pub struct ElGamalSecretKey {
61    scalar: Scalar,
62}
63
64impl ElGamalSecretKey {
65    /// Generate a random ElGamal secret key
66    pub fn generate() -> Self {
67        let mut rng = rand::thread_rng();
68        let mut bytes = [0u8; 32];
69        rng.fill(&mut bytes);
70        let scalar = Scalar::from_bytes_mod_order(bytes);
71        Self { scalar }
72    }
73
74    /// Create an ElGamal secret key from bytes
75    pub fn from_bytes(bytes: &[u8; 32]) -> ElGamalResult<Self> {
76        let scalar = Scalar::from_bytes_mod_order(*bytes);
77        Ok(Self { scalar })
78    }
79
80    /// Export secret key to bytes
81    pub fn to_bytes(&self) -> [u8; 32] {
82        self.scalar.to_bytes()
83    }
84
85    /// Derive public key from secret key
86    pub fn public_key(&self) -> ElGamalPublicKey {
87        let point = RISTRETTO_BASEPOINT_TABLE * &self.scalar;
88        ElGamalPublicKey { point }
89    }
90}
91
92/// ElGamal public key (point in the Ristretto group)
93#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
94pub struct ElGamalPublicKey {
95    point: RistrettoPoint,
96}
97
98impl ElGamalPublicKey {
99    /// Create an ElGamal public key from compressed bytes
100    pub fn from_bytes(bytes: &[u8; 32]) -> ElGamalResult<Self> {
101        let compressed =
102            CompressedRistretto::from_slice(bytes).map_err(|_| ElGamalError::InvalidPublicKey)?;
103        let point = compressed
104            .decompress()
105            .ok_or(ElGamalError::InvalidPublicKey)?;
106        Ok(Self { point })
107    }
108
109    /// Export public key to compressed bytes
110    pub fn to_bytes(&self) -> [u8; 32] {
111        self.point.compress().to_bytes()
112    }
113}
114
/// ElGamal ciphertext (c₁, c₂) where:
/// - c₁ = r*G (ephemeral public key)
/// - c₂ = m*G + r*H (encrypted message)
///
/// where H is the recipient's public key, G the Ristretto basepoint, r a random
/// scalar chosen at encryption time and m the plaintext.
///
/// The byte encoding produced by `to_bytes` is `c1 || c2`, each a 32-byte
/// compressed point. Field order also determines the order in which the derived
/// serde impls (de)serialize the two points.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ElGamalCiphertext {
    // c₁ = r*G
    c1: RistrettoPoint,
    // c₂ = m*G + r*H
    c2: RistrettoPoint,
}
125
126impl ElGamalCiphertext {
127    /// Create an ElGamal ciphertext from compressed bytes (64 bytes: 32 for c1 + 32 for c2)
128    pub fn from_bytes(bytes: &[u8; 64]) -> ElGamalResult<Self> {
129        let mut c1_bytes = [0u8; 32];
130        let mut c2_bytes = [0u8; 32];
131        c1_bytes.copy_from_slice(&bytes[..32]);
132        c2_bytes.copy_from_slice(&bytes[32..]);
133
134        let compressed_c1 = CompressedRistretto::from_slice(&c1_bytes)
135            .map_err(|_| ElGamalError::InvalidCiphertext)?;
136        let compressed_c2 = CompressedRistretto::from_slice(&c2_bytes)
137            .map_err(|_| ElGamalError::InvalidCiphertext)?;
138
139        let c1 = compressed_c1
140            .decompress()
141            .ok_or(ElGamalError::InvalidCiphertext)?;
142        let c2 = compressed_c2
143            .decompress()
144            .ok_or(ElGamalError::InvalidCiphertext)?;
145
146        Ok(Self { c1, c2 })
147    }
148
149    /// Export ciphertext to bytes
150    pub fn to_bytes(&self) -> [u8; 64] {
151        let mut bytes = [0u8; 64];
152        bytes[..32].copy_from_slice(&self.c1.compress().to_bytes());
153        bytes[32..].copy_from_slice(&self.c2.compress().to_bytes());
154        bytes
155    }
156
157    /// Homomorphic addition: E(m₁) + E(m₂) = E(m₁ + m₂)
158    pub fn add(&self, other: &ElGamalCiphertext) -> ElGamalCiphertext {
159        ElGamalCiphertext {
160            c1: self.c1 + other.c1,
161            c2: self.c2 + other.c2,
162        }
163    }
164
165    /// Scalar multiplication: k * E(m) = E(k * m)
166    pub fn mul_scalar(&self, scalar: u64) -> ElGamalCiphertext {
167        let s = Scalar::from(scalar);
168        ElGamalCiphertext {
169            c1: self.c1 * s,
170            c2: self.c2 * s,
171        }
172    }
173
174    /// Re-randomize the ciphertext for unlinkability
175    /// Returns a new ciphertext encrypting the same message but unlinkable to the original
176    pub fn rerandomize(&self, public_key: &ElGamalPublicKey) -> ElGamalCiphertext {
177        let mut rng = rand::thread_rng();
178        let mut r_bytes = [0u8; 32];
179        rng.fill(&mut r_bytes);
180        let r = Scalar::from_bytes_mod_order(r_bytes);
181
182        // Add encryption of zero: (r*G, r*H)
183        let delta_c1 = RISTRETTO_BASEPOINT_TABLE * &r;
184        let delta_c2 = public_key.point * r;
185
186        ElGamalCiphertext {
187            c1: self.c1 + delta_c1,
188            c2: self.c2 + delta_c2,
189        }
190    }
191}
192
193/// ElGamal keypair (secret key + public key)
194pub struct ElGamalKeypair {
195    secret_key: ElGamalSecretKey,
196    public_key: ElGamalPublicKey,
197}
198
199impl ElGamalKeypair {
200    /// Generate a random ElGamal keypair
201    pub fn generate() -> Self {
202        let secret_key = ElGamalSecretKey::generate();
203        let public_key = secret_key.public_key();
204        Self {
205            secret_key,
206            public_key,
207        }
208    }
209
210    /// Create a keypair from a secret key
211    pub fn from_secret_key(secret_key: ElGamalSecretKey) -> Self {
212        let public_key = secret_key.public_key();
213        Self {
214            secret_key,
215            public_key,
216        }
217    }
218
219    /// Get the public key
220    pub fn public_key(&self) -> ElGamalPublicKey {
221        self.public_key
222    }
223
224    /// Get a reference to the secret key
225    pub fn secret_key(&self) -> &ElGamalSecretKey {
226        &self.secret_key
227    }
228
229    /// Encrypt a message (u64 value)
230    ///
231    /// The message is encoded as m*G (point on the curve)
232    /// Ciphertext: (c₁, c₂) = (r*G, m*G + r*H)
233    pub fn encrypt(&self, message: u64) -> ElGamalCiphertext {
234        encrypt(&self.public_key, message)
235    }
236
237    /// Decrypt a ciphertext to recover the original message
238    ///
239    /// Decryption: m*G = c₂ - x*c₁
240    /// Then solve discrete log to get m
241    pub fn decrypt(&self, ciphertext: &ElGamalCiphertext) -> ElGamalResult<u64> {
242        decrypt(&self.secret_key, ciphertext)
243    }
244}
245
246/// Encrypt a message using ElGamal encryption
247pub fn encrypt(public_key: &ElGamalPublicKey, message: u64) -> ElGamalCiphertext {
248    // Generate random ephemeral key
249    let mut rng = rand::thread_rng();
250    let mut r_bytes = [0u8; 32];
251    rng.fill(&mut r_bytes);
252    let r = Scalar::from_bytes_mod_order(r_bytes);
253
254    // Encode message as point: m*G
255    let m_scalar = Scalar::from(message);
256    let m_point = RISTRETTO_BASEPOINT_TABLE * &m_scalar;
257
258    // c₁ = r*G
259    let c1 = RISTRETTO_BASEPOINT_TABLE * &r;
260
261    // c₂ = m*G + r*H
262    let c2 = m_point + (public_key.point * r);
263
264    ElGamalCiphertext { c1, c2 }
265}
266
267/// Decrypt an ElGamal ciphertext
268///
269/// Uses baby-step giant-step algorithm for discrete log
270/// Works for small messages (up to 2^32)
271pub fn decrypt(
272    secret_key: &ElGamalSecretKey,
273    ciphertext: &ElGamalCiphertext,
274) -> ElGamalResult<u64> {
275    // Compute m*G = c₂ - x*c₁
276    let m_point = ciphertext.c2 - (secret_key.scalar * ciphertext.c1);
277
278    // Solve discrete log to get m
279    // For small values, we can use brute force or baby-step giant-step
280    solve_discrete_log(&m_point)
281}
282
283/// Solve discrete log for small values using baby-step giant-step algorithm
284///
285/// This finds m such that m*G = P for small m (up to 2^20 ~ 1 million)
286fn solve_discrete_log(point: &RistrettoPoint) -> ElGamalResult<u64> {
287    const MAX_SEARCH: u64 = 1 << 20; // Search up to 2^20 ~ 1 million
288    const BATCH_SIZE: u64 = 1 << 10; // Baby step size
289
290    // Baby-step giant-step algorithm
291    // Precompute baby steps: i*G for i = 0..BATCH_SIZE
292    let mut baby_steps = std::collections::HashMap::new();
293    let mut current = RistrettoPoint::default(); // 0*G = identity
294    let generator = RISTRETTO_BASEPOINT_TABLE * &Scalar::ONE;
295
296    for i in 0..BATCH_SIZE {
297        baby_steps.insert(current.compress().to_bytes(), i);
298        current += generator;
299    }
300
301    // Giant steps: check if P - j*BATCH_SIZE*G is in baby steps
302    let giant_step = generator * Scalar::from(BATCH_SIZE);
303    let mut current = *point;
304
305    for j in 0..(MAX_SEARCH / BATCH_SIZE) {
306        if let Some(&i) = baby_steps.get(&current.compress().to_bytes()) {
307            return Ok(j * BATCH_SIZE + i);
308        }
309        current -= giant_step;
310    }
311
312    Err(ElGamalError::DecryptionFailed)
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keypair_generation() {
        let keypair = ElGamalKeypair::generate();
        let pk = keypair.public_key();

        // Verify public key can be serialized and deserialized
        let pk_bytes = pk.to_bytes();
        let pk2 = ElGamalPublicKey::from_bytes(&pk_bytes).unwrap();
        assert_eq!(pk, pk2);
    }

    #[test]
    fn test_encrypt_decrypt() {
        let keypair = ElGamalKeypair::generate();
        let message = 42u64;

        let ciphertext = keypair.encrypt(message);
        let decrypted = keypair.decrypt(&ciphertext).unwrap();

        assert_eq!(message, decrypted);
    }

    #[test]
    fn test_homomorphic_addition() {
        let keypair = ElGamalKeypair::generate();
        let msg1 = 100u64;
        let msg2 = 200u64;

        let ct1 = keypair.encrypt(msg1);
        let ct2 = keypair.encrypt(msg2);

        // Homomorphic addition
        let ct_sum = ct1.add(&ct2);

        // Decrypt the sum
        let sum = keypair.decrypt(&ct_sum).unwrap();
        assert_eq!(sum, msg1 + msg2);
    }

    #[test]
    fn test_scalar_multiplication() {
        let keypair = ElGamalKeypair::generate();
        let msg = 50u64;
        let k = 3u64;

        let ct = keypair.encrypt(msg);
        let ct_mult = ct.mul_scalar(k);

        let result = keypair.decrypt(&ct_mult).unwrap();
        assert_eq!(result, msg * k);
    }

    #[test]
    fn test_rerandomization() {
        let keypair = ElGamalKeypair::generate();
        let message = 123u64;

        let ct1 = keypair.encrypt(message);
        let ct2 = ct1.rerandomize(&keypair.public_key());

        // Different ciphertexts
        assert_ne!(ct1, ct2);

        // Same plaintext
        assert_eq!(keypair.decrypt(&ct1).unwrap(), message);
        assert_eq!(keypair.decrypt(&ct2).unwrap(), message);
    }

    #[test]
    fn test_ciphertext_serialization() {
        let keypair = ElGamalKeypair::generate();
        let message = 777u64;

        let ct = keypair.encrypt(message);
        let ct_bytes = ct.to_bytes();
        let ct2 = ElGamalCiphertext::from_bytes(&ct_bytes).unwrap();

        assert_eq!(ct, ct2);
        assert_eq!(keypair.decrypt(&ct2).unwrap(), message);
    }

    #[test]
    fn test_zero_message() {
        let keypair = ElGamalKeypair::generate();
        let message = 0u64;

        let ct = keypair.encrypt(message);
        let decrypted = keypair.decrypt(&ct).unwrap();

        assert_eq!(message, decrypted);
    }

    #[test]
    fn test_large_message() {
        let keypair = ElGamalKeypair::generate();
        let message = 10000u64;

        let ct = keypair.encrypt(message);
        let decrypted = keypair.decrypt(&ct).unwrap();

        assert_eq!(message, decrypted);
    }

    #[test]
    fn test_multiple_additions() {
        let keypair = ElGamalKeypair::generate();
        let values = [10u64, 20, 30, 40, 50];
        let expected_sum: u64 = values.iter().sum();

        let mut ct_sum = keypair.encrypt(0);
        for &value in &values {
            let ct = keypair.encrypt(value);
            ct_sum = ct_sum.add(&ct);
        }

        let result = keypair.decrypt(&ct_sum).unwrap();
        assert_eq!(result, expected_sum);
    }

    #[test]
    fn test_secret_key_serialization() {
        let sk = ElGamalSecretKey::generate();
        let sk_bytes = sk.to_bytes();
        let sk2 = ElGamalSecretKey::from_bytes(&sk_bytes).unwrap();

        assert_eq!(sk.to_bytes(), sk2.to_bytes());
        assert_eq!(sk.public_key().to_bytes(), sk2.public_key().to_bytes());
    }

    #[test]
    fn test_deterministic_public_key() {
        let sk_bytes = [42u8; 32];
        let sk1 = ElGamalSecretKey::from_bytes(&sk_bytes).unwrap();
        let sk2 = ElGamalSecretKey::from_bytes(&sk_bytes).unwrap();

        assert_eq!(sk1.public_key(), sk2.public_key());
    }

    #[test]
    fn test_encryption_randomness() {
        let keypair = ElGamalKeypair::generate();
        let message = 100u64;

        let ct1 = keypair.encrypt(message);
        let ct2 = keypair.encrypt(message);

        // Different ciphertexts due to random ephemeral key
        assert_ne!(ct1, ct2);

        // Same plaintext
        assert_eq!(keypair.decrypt(&ct1).unwrap(), message);
        assert_eq!(keypair.decrypt(&ct2).unwrap(), message);
    }

    #[test]
    fn test_giant_step_boundaries() {
        // Values on either side of the first baby-step/giant-step boundary (2^10).
        let keypair = ElGamalKeypair::generate();
        for &message in &[1023u64, 1024, 1025] {
            let ct = keypair.encrypt(message);
            assert_eq!(keypair.decrypt(&ct).unwrap(), message);
        }
    }

    #[test]
    fn test_max_decryptable_message() {
        // Largest plaintext inside the discrete-log search range (2^20 - 1).
        let keypair = ElGamalKeypair::generate();
        let message = (1u64 << 20) - 1;

        let ct = keypair.encrypt(message);
        assert_eq!(keypair.decrypt(&ct).unwrap(), message);
    }

    #[test]
    fn test_out_of_range_message_fails() {
        // 2^20 is just past the search range: decryption must report failure
        // rather than return a wrong value.
        let keypair = ElGamalKeypair::generate();
        let ct = keypair.encrypt(1u64 << 20);

        assert!(matches!(
            keypair.decrypt(&ct),
            Err(ElGamalError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_wrong_key_fails() {
        // Decrypting under an unrelated key yields an essentially random point,
        // which lands in the 2^20 search range only with negligible probability.
        let alice = ElGamalKeypair::generate();
        let mallory = ElGamalKeypair::generate();
        let ct = alice.encrypt(42);

        assert!(mallory.decrypt(&ct).is_err());
    }
}