1use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
32use curve25519_dalek::scalar::Scalar;
33use curve25519_dalek::traits::Identity;
34use serde::{Deserialize, Serialize};
35use sha2::{Digest, Sha256};
36use thiserror::Error;
37
38#[derive(Error, Debug)]
40pub enum FunctionalEncryptionError {
41 #[error("Invalid input: {0}")]
42 InvalidInput(String),
43 #[error("Decryption failed: {0}")]
44 DecryptionFailed(String),
45 #[error("Serialization error: {0}")]
46 SerializationError(String),
47}
48
49pub type FunctionalEncryptionResult<T> = Result<T, FunctionalEncryptionError>;
51
52#[derive(Clone, Serialize, Deserialize)]
54pub struct IpfeMasterSecretKey {
55 secret_scalars: Vec<Scalar>,
57}
58
59#[derive(Clone, Serialize, Deserialize)]
61pub struct IpfeMasterPublicKey {
62 #[serde(with = "serde_point_vec")]
64 public_points: Vec<RistrettoPoint>,
65 #[serde(with = "serde_point")]
67 generator: RistrettoPoint,
68}
69
70#[derive(Clone, Serialize, Deserialize)]
72pub struct IpfeFunctionalKey {
73 functional_scalar: Scalar,
75 func_vector: Vec<i64>,
77}
78
79#[derive(Clone, Serialize, Deserialize)]
81pub struct IpfeCiphertext {
82 #[serde(with = "serde_point")]
84 c0: RistrettoPoint,
85 #[serde(with = "serde_point_vec")]
87 encrypted_points: Vec<RistrettoPoint>,
88}
89
90mod serde_point {
92 use super::*;
93 use serde::{Deserializer, Serializer};
94
95 pub fn serialize<S>(point: &RistrettoPoint, serializer: S) -> Result<S::Ok, S::Error>
96 where
97 S: Serializer,
98 {
99 let bytes = point.compress().to_bytes();
100 serializer.serialize_bytes(&bytes)
101 }
102
103 pub fn deserialize<'de, D>(deserializer: D) -> Result<RistrettoPoint, D::Error>
104 where
105 D: Deserializer<'de>,
106 {
107 let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
108 if bytes.len() != 32 {
109 return Err(serde::de::Error::custom("invalid point length"));
110 }
111 let mut arr = [0u8; 32];
112 arr.copy_from_slice(&bytes);
113 CompressedRistretto(arr)
114 .decompress()
115 .ok_or_else(|| serde::de::Error::custom("invalid point"))
116 }
117}
118
119mod serde_point_vec {
120 use super::*;
121 use serde::{Deserializer, Serializer};
122
123 pub fn serialize<S>(points: &[RistrettoPoint], serializer: S) -> Result<S::Ok, S::Error>
124 where
125 S: Serializer,
126 {
127 let bytes: Vec<Vec<u8>> = points
128 .iter()
129 .map(|p| p.compress().to_bytes().to_vec())
130 .collect();
131 bytes.serialize(serializer)
132 }
133
134 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<RistrettoPoint>, D::Error>
135 where
136 D: Deserializer<'de>,
137 {
138 let bytes_vec: Vec<Vec<u8>> = Deserialize::deserialize(deserializer)?;
139 bytes_vec
140 .into_iter()
141 .map(|bytes| {
142 if bytes.len() != 32 {
143 return Err(serde::de::Error::custom("invalid point length"));
144 }
145 let mut arr = [0u8; 32];
146 arr.copy_from_slice(&bytes);
147 CompressedRistretto(arr)
148 .decompress()
149 .ok_or_else(|| serde::de::Error::custom("invalid point"))
150 })
151 .collect()
152 }
153}
154
155pub fn ipfe_setup(dimension: usize) -> (IpfeMasterSecretKey, IpfeMasterPublicKey) {
163 let generator = curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
164
165 let mut secret_scalars = Vec::with_capacity(dimension);
166 let mut public_points = Vec::with_capacity(dimension);
167
168 for i in 0..dimension {
169 let mut hasher = Sha256::new();
171 hasher.update(b"ipfe_master_secret");
172 hasher.update(i.to_le_bytes());
173 hasher.update(rand::random::<[u8; 32]>());
174 let hash = hasher.finalize();
175 let scalar = Scalar::from_bytes_mod_order(hash.into());
176
177 let public_point = generator * scalar;
179
180 secret_scalars.push(scalar);
181 public_points.push(public_point);
182 }
183
184 let msk = IpfeMasterSecretKey { secret_scalars };
185 let mpk = IpfeMasterPublicKey {
186 public_points,
187 generator,
188 };
189
190 (msk, mpk)
191}
192
193pub fn ipfe_encrypt(
202 mpk: &IpfeMasterPublicKey,
203 plaintext: &[i64],
204) -> FunctionalEncryptionResult<IpfeCiphertext> {
205 if plaintext.len() != mpk.public_points.len() {
206 return Err(FunctionalEncryptionError::InvalidInput(
207 "plaintext dimension mismatch".to_string(),
208 ));
209 }
210
211 let r = Scalar::from_bytes_mod_order(rand::random::<[u8; 32]>());
213
214 let c0 = mpk.generator * r;
216
217 let mut encrypted_points = Vec::with_capacity(plaintext.len());
218
219 for (i, &value) in plaintext.iter().enumerate() {
220 let value_scalar = Scalar::from(value.unsigned_abs());
222 let value_scalar = if value < 0 {
223 -value_scalar
224 } else {
225 value_scalar
226 };
227
228 let encrypted = (mpk.public_points[i] * r) + (mpk.generator * value_scalar);
230 encrypted_points.push(encrypted);
231 }
232
233 Ok(IpfeCiphertext {
234 c0,
235 encrypted_points,
236 })
237}
238
239pub fn ipfe_keygen(
248 msk: &IpfeMasterSecretKey,
249 func_vector: &[i64],
250) -> FunctionalEncryptionResult<IpfeFunctionalKey> {
251 if func_vector.len() != msk.secret_scalars.len() {
252 return Err(FunctionalEncryptionError::InvalidInput(
253 "function vector dimension mismatch".to_string(),
254 ));
255 }
256
257 let mut functional_scalar = Scalar::ZERO;
259
260 for (i, &value) in func_vector.iter().enumerate() {
261 let value_scalar = Scalar::from(value.unsigned_abs());
262 let value_scalar = if value < 0 {
263 -value_scalar
264 } else {
265 value_scalar
266 };
267
268 functional_scalar += value_scalar * msk.secret_scalars[i];
269 }
270
271 Ok(IpfeFunctionalKey {
272 functional_scalar,
273 func_vector: func_vector.to_vec(),
274 })
275}
276
277pub fn ipfe_decrypt(
286 func_key: &IpfeFunctionalKey,
287 ciphertext: &IpfeCiphertext,
288) -> FunctionalEncryptionResult<i64> {
289 if func_key.func_vector.len() != ciphertext.encrypted_points.len() {
291 return Err(FunctionalEncryptionError::InvalidInput(
292 "function vector and ciphertext dimension mismatch".to_string(),
293 ));
294 }
295
296 let generator = curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
304 let mut result_point = RistrettoPoint::identity();
305
306 for (i, &y_i) in func_key.func_vector.iter().enumerate() {
308 let y_scalar = Scalar::from(y_i.unsigned_abs());
309 let y_scalar = if y_i < 0 { -y_scalar } else { y_scalar };
310
311 result_point += ciphertext.encrypted_points[i] * y_scalar;
312 }
313
314 result_point -= ciphertext.c0 * func_key.functional_scalar;
316
317 for i in 0..=10000 {
321 if result_point == generator * Scalar::from(i as u64) {
322 return Ok(i);
323 }
324 if result_point == generator * (-Scalar::from(i as u64)) {
325 return Ok(-i);
326 }
327 }
328
329 Err(FunctionalEncryptionError::DecryptionFailed(
330 "discrete log too large".to_string(),
331 ))
332}
333
334pub struct MultiClientIpfe {
336 dimension: usize,
337 master_keys: Vec<(IpfeMasterSecretKey, IpfeMasterPublicKey)>,
338}
339
340impl MultiClientIpfe {
341 pub fn setup(num_clients: usize, dimension: usize) -> Self {
343 let mut master_keys = Vec::with_capacity(num_clients);
344
345 for _ in 0..num_clients {
346 master_keys.push(ipfe_setup(dimension));
347 }
348
349 Self {
350 dimension,
351 master_keys,
352 }
353 }
354
355 pub fn get_public_key(&self, client_id: usize) -> Option<&IpfeMasterPublicKey> {
357 self.master_keys.get(client_id).map(|(_, mpk)| mpk)
358 }
359
360 pub fn keygen(
362 &self,
363 func_vector: &[i64],
364 ) -> FunctionalEncryptionResult<Vec<IpfeFunctionalKey>> {
365 if func_vector.len() != self.dimension {
366 return Err(FunctionalEncryptionError::InvalidInput(
367 "function vector dimension mismatch".to_string(),
368 ));
369 }
370
371 let mut func_keys = Vec::with_capacity(self.master_keys.len());
372
373 for (msk, _) in &self.master_keys {
374 func_keys.push(ipfe_keygen(msk, func_vector)?);
375 }
376
377 Ok(func_keys)
378 }
379
380 pub fn aggregate_decrypt(
382 func_keys: &[IpfeFunctionalKey],
383 ciphertexts: &[IpfeCiphertext],
384 ) -> FunctionalEncryptionResult<i64> {
385 if func_keys.len() != ciphertexts.len() {
386 return Err(FunctionalEncryptionError::InvalidInput(
387 "number of keys and ciphertexts must match".to_string(),
388 ));
389 }
390
391 let mut total = 0i64;
393 for (fk, ct) in func_keys.iter().zip(ciphertexts.iter()) {
394 total += ipfe_decrypt(fk, ct)?;
395 }
396
397 Ok(total)
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs setup -> encrypt -> keygen -> decrypt with a fresh key pair sized
    /// to `plaintext` and returns the decrypted inner product.
    fn roundtrip(plaintext: &[i64], func_vector: &[i64]) -> i64 {
        let (msk, mpk) = ipfe_setup(plaintext.len());
        let ciphertext = ipfe_encrypt(&mpk, plaintext).unwrap();
        let func_key = ipfe_keygen(&msk, func_vector).unwrap();
        ipfe_decrypt(&func_key, &ciphertext).unwrap()
    }

    #[test]
    fn test_ipfe_basic() {
        // 5*1 + 10*2 + 15*3
        assert_eq!(roundtrip(&[5, 10, 15], &[1, 2, 3]), 70);
    }

    #[test]
    fn test_ipfe_negative_values() {
        // 10*2 + (-5)*1 + 8*(-1) + (-3)*4
        assert_eq!(roundtrip(&[10, -5, 8, -3], &[2, 1, -1, 4]), -5);
    }

    #[test]
    fn test_ipfe_zero_vector() {
        assert_eq!(roundtrip(&[0, 0, 0], &[1, 2, 3]), 0);
    }

    #[test]
    fn test_ipfe_dimension_mismatch() {
        let (msk, mpk) = ipfe_setup(3);

        // Too short for the public key.
        assert!(ipfe_encrypt(&mpk, &[1, 2]).is_err());

        // Too long for the secret key.
        assert!(ipfe_keygen(&msk, &[1, 2, 3, 4]).is_err());
    }

    #[test]
    fn test_ipfe_multiple_keys() {
        let (msk, mpk) = ipfe_setup(3);
        let ciphertext = ipfe_encrypt(&mpk, &[4, 5, 6]).unwrap();

        // Each unit vector selects exactly one plaintext coordinate.
        let cases: [([i64; 3], i64); 3] = [([1, 0, 0], 4), ([0, 1, 0], 5), ([0, 0, 1], 6)];
        for (selector, expected) in cases.iter() {
            let func_key = ipfe_keygen(&msk, selector).unwrap();
            let result = ipfe_decrypt(&func_key, &ciphertext).unwrap();
            assert_eq!(result, *expected);
        }
    }

    #[test]
    fn test_ipfe_serialization() {
        let (msk, mpk) = ipfe_setup(3);

        let encoded_mpk = crate::codec::encode(&mpk).unwrap();
        let decoded_mpk: IpfeMasterPublicKey = crate::codec::decode(&encoded_mpk).unwrap();

        let encoded_msk = crate::codec::encode(&msk).unwrap();
        let decoded_msk: IpfeMasterSecretKey = crate::codec::decode(&encoded_msk).unwrap();

        // The restored keys must still interoperate.
        let ciphertext = ipfe_encrypt(&decoded_mpk, &[7, 8, 9]).unwrap();
        let func_key = ipfe_keygen(&decoded_msk, &[1, 1, 1]).unwrap();

        assert_eq!(ipfe_decrypt(&func_key, &ciphertext).unwrap(), 24);
    }

    #[test]
    fn test_multi_client_ipfe() {
        let mcipfe = MultiClientIpfe::setup(3, 2);

        let inputs: [[i64; 2]; 3] = [[10, 20], [5, 15], [3, 7]];
        let ciphertexts: Vec<IpfeCiphertext> = inputs
            .iter()
            .enumerate()
            .map(|(client_id, plaintext)| {
                ipfe_encrypt(mcipfe.get_public_key(client_id).unwrap(), plaintext).unwrap()
            })
            .collect();

        let func_keys = mcipfe.keygen(&[2, 1]).unwrap();

        // (10*2 + 20) + (5*2 + 15) + (3*2 + 7)
        let total = MultiClientIpfe::aggregate_decrypt(&func_keys, &ciphertexts).unwrap();
        assert_eq!(total, 78);
    }

    #[test]
    fn test_multi_client_dimension_mismatch() {
        let mcipfe = MultiClientIpfe::setup(2, 3);

        assert!(mcipfe.keygen(&[1, 2]).is_err());
    }

    #[test]
    fn test_multi_client_aggregate_mismatch() {
        let mcipfe = MultiClientIpfe::setup(2, 2);

        let only_ciphertext = ipfe_encrypt(mcipfe.get_public_key(0).unwrap(), &[1, 2]).unwrap();
        let func_keys = mcipfe.keygen(&[1, 1]).unwrap();

        // Two keys but a single ciphertext.
        let outcome = MultiClientIpfe::aggregate_decrypt(&func_keys, &[only_ciphertext]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_ipfe_large_dimension() {
        let dimension = 10;

        let plaintext: Vec<i64> = (1..=dimension as i64).collect();
        let func_vector: Vec<i64> = vec![1; dimension];

        let expected: i64 = plaintext.iter().sum();
        assert_eq!(roundtrip(&plaintext, &func_vector), expected);
    }

    #[test]
    fn test_functional_key_serialization() {
        let (msk, mpk) = ipfe_setup(3);

        let func_key = ipfe_keygen(&msk, &[2, 3, 4]).unwrap();
        let encoded = crate::codec::encode(&func_key).unwrap();
        let decoded: IpfeFunctionalKey = crate::codec::decode(&encoded).unwrap();

        let ciphertext = ipfe_encrypt(&mpk, &[1, 2, 3]).unwrap();

        // 1*2 + 2*3 + 3*4
        assert_eq!(ipfe_decrypt(&decoded, &ciphertext).unwrap(), 20);
    }

    #[test]
    fn test_ciphertext_serialization() {
        let (msk, mpk) = ipfe_setup(3);

        let ciphertext = ipfe_encrypt(&mpk, &[5, 6, 7]).unwrap();
        let encoded = crate::codec::encode(&ciphertext).unwrap();
        let decoded: IpfeCiphertext = crate::codec::decode(&encoded).unwrap();

        let func_key = ipfe_keygen(&msk, &[1, 2, 1]).unwrap();

        // 5*1 + 6*2 + 7*1
        assert_eq!(ipfe_decrypt(&func_key, &decoded).unwrap(), 24);
    }

    #[test]
    fn test_ipfe_single_dimension() {
        assert_eq!(roundtrip(&[42], &[3]), 126);
    }

    #[test]
    fn test_ipfe_orthogonal_vectors() {
        assert_eq!(roundtrip(&[1, 0, 0], &[0, 1, 0]), 0);
    }
}