chie_crypto/
functional_encryption.rs

//! Functional Encryption (FE) primitives
//!
//! Functional encryption allows a client holding a secret key SK_f to learn f(x) from an
//! encryption of x, without learning anything else about x. This is useful for
//! privacy-preserving computation, where you want to compute on encrypted data.
//!
//! This module implements Inner Product Functional Encryption (IPFE), one of the most
//! practical forms of FE, which allows inner products to be computed over encrypted vectors.
//! [`ipfe_decrypt`] recovers the inner product with a brute-force discrete-log search, so its
//! result must lie in the range [-10000, 10000].
//!
//! # Example
//!
//! ```
//! use chie_crypto::functional_encryption::*;
//!
//! // Setup master keys for vectors of length 3
//! let (msk, mpk) = ipfe_setup(3);
//!
//! // Encrypt a vector [5, 10, 15]
//! let plaintext = vec![5, 10, 15];
//! let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();
//!
//! // Generate a functional key for computing inner product with [1, 2, 3]
//! let func_vector = vec![1, 2, 3];
//! let func_key = ipfe_keygen(&msk, &func_vector).unwrap();
//!
//! // Decrypt to get the inner product: 5*1 + 10*2 + 15*3 = 70
//! let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
//! assert_eq!(result, 70);
//! ```
30
31use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
32use curve25519_dalek::scalar::Scalar;
33use curve25519_dalek::traits::Identity;
34use serde::{Deserialize, Serialize};
35use sha2::{Digest, Sha256};
36use thiserror::Error;
37
38/// Functional encryption error types
39#[derive(Error, Debug)]
40pub enum FunctionalEncryptionError {
41    #[error("Invalid input: {0}")]
42    InvalidInput(String),
43    #[error("Decryption failed: {0}")]
44    DecryptionFailed(String),
45    #[error("Serialization error: {0}")]
46    SerializationError(String),
47}
48
49/// Result type for functional encryption operations
50pub type FunctionalEncryptionResult<T> = Result<T, FunctionalEncryptionError>;
51
52/// Master secret key for IPFE
53#[derive(Clone, Serialize, Deserialize)]
54pub struct IpfeMasterSecretKey {
55    /// Secret key scalars (one per dimension)
56    secret_scalars: Vec<Scalar>,
57}
58
59/// Master public key for IPFE
60#[derive(Clone, Serialize, Deserialize)]
61pub struct IpfeMasterPublicKey {
62    /// Public key points (one per dimension)
63    #[serde(with = "serde_point_vec")]
64    public_points: Vec<RistrettoPoint>,
65    /// Base generator
66    #[serde(with = "serde_point")]
67    generator: RistrettoPoint,
68}
69
70/// Functional secret key for computing inner products
71#[derive(Clone, Serialize, Deserialize)]
72pub struct IpfeFunctionalKey {
73    /// Functional key scalar (derived from master secret and function vector)
74    functional_scalar: Scalar,
75    /// Function vector (needed for decryption)
76    func_vector: Vec<i64>,
77}
78
79/// Ciphertext for IPFE
80#[derive(Clone, Serialize, Deserialize)]
81pub struct IpfeCiphertext {
82    /// c_0 = g^r
83    #[serde(with = "serde_point")]
84    c0: RistrettoPoint,
85    /// c_i = h_i^r * g^{x_i}
86    #[serde(with = "serde_point_vec")]
87    encrypted_points: Vec<RistrettoPoint>,
88}
89
90// Serde helpers for RistrettoPoint
91mod serde_point {
92    use super::*;
93    use serde::{Deserializer, Serializer};
94
95    pub fn serialize<S>(point: &RistrettoPoint, serializer: S) -> Result<S::Ok, S::Error>
96    where
97        S: Serializer,
98    {
99        let bytes = point.compress().to_bytes();
100        serializer.serialize_bytes(&bytes)
101    }
102
103    pub fn deserialize<'de, D>(deserializer: D) -> Result<RistrettoPoint, D::Error>
104    where
105        D: Deserializer<'de>,
106    {
107        let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
108        if bytes.len() != 32 {
109            return Err(serde::de::Error::custom("invalid point length"));
110        }
111        let mut arr = [0u8; 32];
112        arr.copy_from_slice(&bytes);
113        CompressedRistretto(arr)
114            .decompress()
115            .ok_or_else(|| serde::de::Error::custom("invalid point"))
116    }
117}
118
119mod serde_point_vec {
120    use super::*;
121    use serde::{Deserializer, Serializer};
122
123    pub fn serialize<S>(points: &[RistrettoPoint], serializer: S) -> Result<S::Ok, S::Error>
124    where
125        S: Serializer,
126    {
127        let bytes: Vec<Vec<u8>> = points
128            .iter()
129            .map(|p| p.compress().to_bytes().to_vec())
130            .collect();
131        bytes.serialize(serializer)
132    }
133
134    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<RistrettoPoint>, D::Error>
135    where
136        D: Deserializer<'de>,
137    {
138        let bytes_vec: Vec<Vec<u8>> = Deserialize::deserialize(deserializer)?;
139        bytes_vec
140            .into_iter()
141            .map(|bytes| {
142                if bytes.len() != 32 {
143                    return Err(serde::de::Error::custom("invalid point length"));
144                }
145                let mut arr = [0u8; 32];
146                arr.copy_from_slice(&bytes);
147                CompressedRistretto(arr)
148                    .decompress()
149                    .ok_or_else(|| serde::de::Error::custom("invalid point"))
150            })
151            .collect()
152    }
153}
154
155/// Generate master secret and public keys for IPFE
156///
157/// # Arguments
158/// * `dimension` - The dimension of vectors to be encrypted
159///
160/// # Returns
161/// A tuple of (master_secret_key, master_public_key)
162pub fn ipfe_setup(dimension: usize) -> (IpfeMasterSecretKey, IpfeMasterPublicKey) {
163    let generator = curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
164
165    let mut secret_scalars = Vec::with_capacity(dimension);
166    let mut public_points = Vec::with_capacity(dimension);
167
168    for i in 0..dimension {
169        // Generate secret scalar from hash
170        let mut hasher = Sha256::new();
171        hasher.update(b"ipfe_master_secret");
172        hasher.update(i.to_le_bytes());
173        hasher.update(rand::random::<[u8; 32]>());
174        let hash = hasher.finalize();
175        let scalar = Scalar::from_bytes_mod_order(hash.into());
176
177        // Public key is g^s
178        let public_point = generator * scalar;
179
180        secret_scalars.push(scalar);
181        public_points.push(public_point);
182    }
183
184    let msk = IpfeMasterSecretKey { secret_scalars };
185    let mpk = IpfeMasterPublicKey {
186        public_points,
187        generator,
188    };
189
190    (msk, mpk)
191}
192
193/// Encrypt a plaintext vector using the master public key
194///
195/// # Arguments
196/// * `mpk` - Master public key
197/// * `plaintext` - Vector of integers to encrypt
198///
199/// # Returns
200/// Encrypted ciphertext
201pub fn ipfe_encrypt(
202    mpk: &IpfeMasterPublicKey,
203    plaintext: &[i64],
204) -> FunctionalEncryptionResult<IpfeCiphertext> {
205    if plaintext.len() != mpk.public_points.len() {
206        return Err(FunctionalEncryptionError::InvalidInput(
207            "plaintext dimension mismatch".to_string(),
208        ));
209    }
210
211    // Generate random scalar r
212    let r = Scalar::from_bytes_mod_order(rand::random::<[u8; 32]>());
213
214    // c_0 = g^r
215    let c0 = mpk.generator * r;
216
217    let mut encrypted_points = Vec::with_capacity(plaintext.len());
218
219    for (i, &value) in plaintext.iter().enumerate() {
220        // Convert value to scalar
221        let value_scalar = Scalar::from(value.unsigned_abs());
222        let value_scalar = if value < 0 {
223            -value_scalar
224        } else {
225            value_scalar
226        };
227
228        // c_i = h_i^r * g^{x_i}
229        let encrypted = (mpk.public_points[i] * r) + (mpk.generator * value_scalar);
230        encrypted_points.push(encrypted);
231    }
232
233    Ok(IpfeCiphertext {
234        c0,
235        encrypted_points,
236    })
237}
238
239/// Generate a functional secret key for computing inner product with a given vector
240///
241/// # Arguments
242/// * `msk` - Master secret key
243/// * `func_vector` - Vector to compute inner product with
244///
245/// # Returns
246/// Functional secret key
247pub fn ipfe_keygen(
248    msk: &IpfeMasterSecretKey,
249    func_vector: &[i64],
250) -> FunctionalEncryptionResult<IpfeFunctionalKey> {
251    if func_vector.len() != msk.secret_scalars.len() {
252        return Err(FunctionalEncryptionError::InvalidInput(
253            "function vector dimension mismatch".to_string(),
254        ));
255    }
256
257    // Compute functional key: sum of (y_i * s_i)
258    let mut functional_scalar = Scalar::ZERO;
259
260    for (i, &value) in func_vector.iter().enumerate() {
261        let value_scalar = Scalar::from(value.unsigned_abs());
262        let value_scalar = if value < 0 {
263            -value_scalar
264        } else {
265            value_scalar
266        };
267
268        functional_scalar += value_scalar * msk.secret_scalars[i];
269    }
270
271    Ok(IpfeFunctionalKey {
272        functional_scalar,
273        func_vector: func_vector.to_vec(),
274    })
275}
276
277/// Decrypt a ciphertext using a functional key to compute the inner product
278///
279/// # Arguments
280/// * `func_key` - Functional secret key
281/// * `ciphertext` - Encrypted vector
282///
283/// # Returns
284/// The inner product result
285pub fn ipfe_decrypt(
286    func_key: &IpfeFunctionalKey,
287    ciphertext: &IpfeCiphertext,
288) -> FunctionalEncryptionResult<i64> {
289    // Check dimension match
290    if func_key.func_vector.len() != ciphertext.encrypted_points.len() {
291        return Err(FunctionalEncryptionError::InvalidInput(
292            "function vector and ciphertext dimension mismatch".to_string(),
293        ));
294    }
295
296    // Compute: sum(y_i * c_i) - sk_y * c_0
297    // This gives: sum(y_i * (h_i^r * g^{x_i})) - (sum(y_i * s_i)) * g^r
298    //           = sum(y_i * g^{s_i * r} + y_i * g^{x_i}) - g^{r * sum(y_i * s_i)}
299    //           = sum(y_i * g^{s_i * r}) + sum(y_i * g^{x_i}) - g^{r * sum(y_i * s_i)}
300    //           = g^{r * sum(y_i * s_i)} + g^{sum(y_i * x_i)} - g^{r * sum(y_i * s_i)}
301    //           = g^{<x,y>}
302
303    let generator = curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
304    let mut result_point = RistrettoPoint::identity();
305
306    // Compute weighted sum: sum(y_i * c_i)
307    for (i, &y_i) in func_key.func_vector.iter().enumerate() {
308        let y_scalar = Scalar::from(y_i.unsigned_abs());
309        let y_scalar = if y_i < 0 { -y_scalar } else { y_scalar };
310
311        result_point += ciphertext.encrypted_points[i] * y_scalar;
312    }
313
314    // Subtract: sk_y * c_0
315    result_point -= ciphertext.c0 * func_key.functional_scalar;
316
317    // Now result_point should be g^{<x,y>}
318    // Discrete log solver for small values (brute force)
319    // This works for results in range [-10000, 10000]
320    for i in 0..=10000 {
321        if result_point == generator * Scalar::from(i as u64) {
322            return Ok(i);
323        }
324        if result_point == generator * (-Scalar::from(i as u64)) {
325            return Ok(-i);
326        }
327    }
328
329    Err(FunctionalEncryptionError::DecryptionFailed(
330        "discrete log too large".to_string(),
331    ))
332}
333
334/// Multi-client IPFE setup for privacy-preserving aggregation
335pub struct MultiClientIpfe {
336    dimension: usize,
337    master_keys: Vec<(IpfeMasterSecretKey, IpfeMasterPublicKey)>,
338}
339
340impl MultiClientIpfe {
341    /// Setup multi-client IPFE for n clients
342    pub fn setup(num_clients: usize, dimension: usize) -> Self {
343        let mut master_keys = Vec::with_capacity(num_clients);
344
345        for _ in 0..num_clients {
346            master_keys.push(ipfe_setup(dimension));
347        }
348
349        Self {
350            dimension,
351            master_keys,
352        }
353    }
354
355    /// Get public key for client i
356    pub fn get_public_key(&self, client_id: usize) -> Option<&IpfeMasterPublicKey> {
357        self.master_keys.get(client_id).map(|(_, mpk)| mpk)
358    }
359
360    /// Generate functional key for computing sum of inner products
361    pub fn keygen(
362        &self,
363        func_vector: &[i64],
364    ) -> FunctionalEncryptionResult<Vec<IpfeFunctionalKey>> {
365        if func_vector.len() != self.dimension {
366            return Err(FunctionalEncryptionError::InvalidInput(
367                "function vector dimension mismatch".to_string(),
368            ));
369        }
370
371        let mut func_keys = Vec::with_capacity(self.master_keys.len());
372
373        for (msk, _) in &self.master_keys {
374            func_keys.push(ipfe_keygen(msk, func_vector)?);
375        }
376
377        Ok(func_keys)
378    }
379
380    /// Aggregate decrypt multiple ciphertexts from different clients
381    pub fn aggregate_decrypt(
382        func_keys: &[IpfeFunctionalKey],
383        ciphertexts: &[IpfeCiphertext],
384    ) -> FunctionalEncryptionResult<i64> {
385        if func_keys.len() != ciphertexts.len() {
386            return Err(FunctionalEncryptionError::InvalidInput(
387                "number of keys and ciphertexts must match".to_string(),
388            ));
389        }
390
391        // Sum all individual results
392        let mut total = 0i64;
393        for (fk, ct) in func_keys.iter().zip(ciphertexts.iter()) {
394            total += ipfe_decrypt(fk, ct)?;
395        }
396
397        Ok(total)
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ipfe_basic() {
        let (msk, mpk) = ipfe_setup(3);

        let plaintext = vec![5, 10, 15];
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        let func_vector = vec![1, 2, 3];
        let func_key = ipfe_keygen(&msk, &func_vector).unwrap();

        let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
        assert_eq!(result, 5 + 10 * 2 + 15 * 3); // 70
    }

    #[test]
    fn test_ipfe_negative_values() {
        let (msk, mpk) = ipfe_setup(4);

        let plaintext = vec![10, -5, 8, -3];
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        let func_vector = vec![2, 1, -1, 4];
        let func_key = ipfe_keygen(&msk, &func_vector).unwrap();

        let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
        assert_eq!(result, 10 * 2 + (-5) + -8 + (-3) * 4); // 20 - 5 - 8 - 12 = -5
    }

    #[test]
    fn test_ipfe_zero_vector() {
        let (msk, mpk) = ipfe_setup(3);

        let plaintext = vec![0, 0, 0];
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        let func_vector = vec![1, 2, 3];
        let func_key = ipfe_keygen(&msk, &func_vector).unwrap();

        let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
        assert_eq!(result, 0);
    }

    #[test]
    fn test_ipfe_zero_dimension() {
        // Degenerate but valid: empty vectors have inner product 0.
        let (msk, mpk) = ipfe_setup(0);

        let ciphertext = ipfe_encrypt(&mpk, &[]).unwrap();
        let func_key = ipfe_keygen(&msk, &[]).unwrap();

        assert_eq!(ipfe_decrypt(&func_key, &ciphertext).unwrap(), 0);
    }

    #[test]
    fn test_ipfe_dimension_mismatch() {
        let (msk, mpk) = ipfe_setup(3);

        let plaintext = vec![1, 2];
        let result = ipfe_encrypt(&mpk, &plaintext);
        assert!(result.is_err());

        let func_vector = vec![1, 2, 3, 4];
        let result = ipfe_keygen(&msk, &func_vector);
        assert!(result.is_err());
    }

    #[test]
    fn test_ipfe_decrypt_dimension_mismatch() {
        // A 2-dimensional functional key must not be applied to a 3-dimensional ciphertext.
        let (msk_two, _) = ipfe_setup(2);
        let (_, mpk_three) = ipfe_setup(3);

        let ciphertext = ipfe_encrypt(&mpk_three, &[1, 2, 3]).unwrap();
        let func_key = ipfe_keygen(&msk_two, &[1, 1]).unwrap();

        let result = ipfe_decrypt(&func_key, &ciphertext);
        assert!(matches!(
            result,
            Err(FunctionalEncryptionError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_ipfe_multiple_keys() {
        let (msk, mpk) = ipfe_setup(3);

        let plaintext = vec![4, 5, 6];
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        // First functional key
        let func_vector1 = vec![1, 0, 0];
        let func_key1 = ipfe_keygen(&msk, &func_vector1).unwrap();
        let result1 = ipfe_decrypt(&func_key1, &ciphertext).unwrap();
        assert_eq!(result1, 4);

        // Second functional key
        let func_vector2 = vec![0, 1, 0];
        let func_key2 = ipfe_keygen(&msk, &func_vector2).unwrap();
        let result2 = ipfe_decrypt(&func_key2, &ciphertext).unwrap();
        assert_eq!(result2, 5);

        // Third functional key
        let func_vector3 = vec![0, 0, 1];
        let func_key3 = ipfe_keygen(&msk, &func_vector3).unwrap();
        let result3 = ipfe_decrypt(&func_key3, &ciphertext).unwrap();
        assert_eq!(result3, 6);
    }

    #[test]
    fn test_ipfe_serialization() {
        let (msk, mpk) = ipfe_setup(3);

        // Serialize and deserialize public key
        let mpk_bytes = crate::codec::encode(&mpk).unwrap();
        let mpk_restored: IpfeMasterPublicKey = crate::codec::decode(&mpk_bytes).unwrap();

        // Serialize and deserialize secret key
        let msk_bytes = crate::codec::encode(&msk).unwrap();
        let msk_restored: IpfeMasterSecretKey = crate::codec::decode(&msk_bytes).unwrap();

        // Test that restored keys work
        let plaintext = vec![7, 8, 9];
        let ciphertext = ipfe_encrypt(&mpk_restored, &plaintext).unwrap();

        let func_vector = vec![1, 1, 1];
        let func_key = ipfe_keygen(&msk_restored, &func_vector).unwrap();

        let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
        assert_eq!(result, 24);
    }

    #[test]
    fn test_multi_client_ipfe() {
        let mcipfe = MultiClientIpfe::setup(3, 2);

        // Each client encrypts their vector
        let plaintext1 = vec![10, 20];
        let plaintext2 = vec![5, 15];
        let plaintext3 = vec![3, 7];

        let ct1 = ipfe_encrypt(mcipfe.get_public_key(0).unwrap(), &plaintext1).unwrap();
        let ct2 = ipfe_encrypt(mcipfe.get_public_key(1).unwrap(), &plaintext2).unwrap();
        let ct3 = ipfe_encrypt(mcipfe.get_public_key(2).unwrap(), &plaintext3).unwrap();

        // Generate functional keys for computing weighted sum
        let func_vector = vec![2, 1];
        let func_keys = mcipfe.keygen(&func_vector).unwrap();

        // Aggregate decrypt
        let result = MultiClientIpfe::aggregate_decrypt(&func_keys, &[ct1, ct2, ct3]).unwrap();

        // Expected: (10*2 + 20*1) + (5*2 + 15*1) + (3*2 + 7*1) = 40 + 25 + 13 = 78
        assert_eq!(result, 78);
    }

    #[test]
    fn test_multi_client_dimension_mismatch() {
        let mcipfe = MultiClientIpfe::setup(2, 3);

        let func_vector = vec![1, 2];
        let result = mcipfe.keygen(&func_vector);
        assert!(result.is_err());
    }

    #[test]
    fn test_multi_client_aggregate_mismatch() {
        let mcipfe = MultiClientIpfe::setup(2, 2);

        let plaintext = vec![1, 2];
        let ct1 = ipfe_encrypt(mcipfe.get_public_key(0).unwrap(), &plaintext).unwrap();

        let func_vector = vec![1, 1];
        let func_keys = mcipfe.keygen(&func_vector).unwrap();

        // Only one ciphertext but two keys
        let result = MultiClientIpfe::aggregate_decrypt(&func_keys, &[ct1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_ipfe_large_dimension() {
        let dimension = 10;
        let (msk, mpk) = ipfe_setup(dimension);

        let plaintext: Vec<i64> = (1..=dimension as i64).collect();
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        let func_vector = vec![1; dimension];
        let func_key = ipfe_keygen(&msk, &func_vector).unwrap();

        let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
        let expected: i64 = (1..=dimension as i64).sum();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_functional_key_serialization() {
        let (msk, mpk) = ipfe_setup(3);

        let func_vector = vec![2, 3, 4];
        let func_key = ipfe_keygen(&msk, &func_vector).unwrap();

        // Serialize and deserialize functional key
        let fk_bytes = crate::codec::encode(&func_key).unwrap();
        let fk_restored: IpfeFunctionalKey = crate::codec::decode(&fk_bytes).unwrap();

        // Test that it works
        let plaintext = vec![1, 2, 3];
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        let result = ipfe_decrypt(&fk_restored, &ciphertext).unwrap();
        assert_eq!(result, 2 + 2 * 3 + 3 * 4); // 20
    }

    #[test]
    fn test_ciphertext_serialization() {
        let (msk, mpk) = ipfe_setup(3);

        let plaintext = vec![5, 6, 7];
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        // Serialize and deserialize ciphertext
        let ct_bytes = crate::codec::encode(&ciphertext).unwrap();
        let ct_restored: IpfeCiphertext = crate::codec::decode(&ct_bytes).unwrap();

        // Test that it works
        let func_vector = vec![1, 2, 1];
        let func_key = ipfe_keygen(&msk, &func_vector).unwrap();

        let result = ipfe_decrypt(&func_key, &ct_restored).unwrap();
        assert_eq!(result, 5 + 6 * 2 + 7); // 24
    }

    #[test]
    fn test_ipfe_single_dimension() {
        let (msk, mpk) = ipfe_setup(1);

        let plaintext = vec![42];
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        let func_vector = vec![3];
        let func_key = ipfe_keygen(&msk, &func_vector).unwrap();

        let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
        assert_eq!(result, 42 * 3);
    }

    #[test]
    fn test_ipfe_orthogonal_vectors() {
        let (msk, mpk) = ipfe_setup(3);

        let plaintext = vec![1, 0, 0];
        let ciphertext = ipfe_encrypt(&mpk, &plaintext).unwrap();

        // Orthogonal function vector
        let func_vector = vec![0, 1, 0];
        let func_key = ipfe_keygen(&msk, &func_vector).unwrap();

        let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
        assert_eq!(result, 0); // Orthogonal vectors give zero
    }
}