use num_bigint::{BigUint, RandBigInt};
use num_prime::RandPrime;
use num_traits::{One, Zero};
use serde::{Deserialize, Serialize};
/// Paillier public key.
///
/// Holds the modulus `n`, its square `n_squared` and the generator `g`. `n_squared` is stored
/// alongside `n` because encryption, decryption and both homomorphic operations reduce modulo
/// it. `PaillierKeypair::generate` sets `g = n + 1`.
///
/// NOTE(review): all fields are `pub` and the type derives `Deserialize`, so nothing here
/// enforces `n_squared == n * n` for keys built by hand or decoded from bytes.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PaillierPublicKey {
    /// RSA-style modulus `n = p * q`.
    pub n: BigUint,
    /// `n * n`, the modulus for all ciphertext arithmetic.
    pub n_squared: BigUint,
    /// Generator used by `encrypt` (`n + 1` for generated keys).
    pub g: BigUint,
}
/// Paillier private key material, as computed by `PaillierKeypair::generate`.
///
/// The fields are private to this module and are only read by decryption.
///
/// NOTE(review): unlike the public types this does not derive `Debug`. Presumably that keeps
/// the secret values out of logs — confirm before adding one.
#[derive(Clone, Serialize, Deserialize)]
pub struct PaillierPrivateKey {
    // lambda = lcm(p - 1, q - 1).
    lambda: BigUint,
    // mu = L(g^lambda mod n^2)^(-1) mod n.
    mu: BigUint,
}
/// A Paillier public key together with its matching private key.
///
/// The type derives `Serialize`, so serializing a keypair (e.g. through `to_bytes`) also
/// encodes the secret `private_key`. Treat the encoded bytes as secret material.
#[derive(Clone, Serialize, Deserialize)]
pub struct PaillierKeypair {
    /// Key used for encryption and the homomorphic operations.
    pub public_key: PaillierPublicKey,
    /// Key used for decryption.
    pub private_key: PaillierPrivateKey,
}
/// A Paillier ciphertext: a single residue `c`.
///
/// The functions in this module always reduce `c` modulo the producing key's `n²`. The field
/// is `pub` and deserializable, though, so nothing enforces that range for externally supplied
/// values.
///
/// `PartialEq` compares the raw residues. Encryption is randomized, so two encryptions of the
/// same plaintext are generally *not* equal.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct PaillierCiphertext {
    /// Ciphertext value, normally in `[0, n²)`.
    pub c: BigUint,
}
impl PaillierKeypair {
    /// Generates a fresh Paillier keypair from two distinct random primes of `bits / 2` bits.
    ///
    /// The modulus is `n = p * q`. Depending on the primes' top bits it has `bits` or
    /// `bits - 1` bits. The generator uses the standard simplification `g = n + 1`. The
    /// private key holds `lambda = lcm(p - 1, q - 1)` and
    /// `mu = L(g^lambda mod n²)^(-1) mod n`.
    ///
    /// `bits / 2` must admit more than one prime. For realistic sizes (hundreds of bits) the
    /// distinctness redraw below practically never runs.
    ///
    /// # Panics
    /// Panics if `mu` cannot be computed because `L(g^lambda mod n²)` is not invertible
    /// modulo `n`.
    pub fn generate(bits: usize) -> Self {
        let mut rng = rand_core06::OsRng;
        let half_bits = bits / 2;
        let p: BigUint = rng.gen_prime(half_bits, None);
        // p == q would make n a perfect square. Such an n is factorable with an integer square
        // root, and gcd(n, phi(n)) = p != 1 breaks decryption. Redraw until q is distinct.
        let q: BigUint = loop {
            let candidate: BigUint = rng.gen_prime(half_bits, None);
            if candidate != p {
                break candidate;
            }
        };
        let n = &p * &q;
        let n_squared = &n * &n;
        let g: BigUint = &n + BigUint::one();
        let p_minus_1 = &p - BigUint::one();
        let q_minus_1 = &q - BigUint::one();
        let lambda = lcm(&p_minus_1, &q_minus_1);
        let g_lambda = g.modpow(&lambda, &n_squared);
        let l_value = l_function(&g_lambda, &n);
        let mu = mod_inverse(&l_value, &n);
        Self {
            public_key: PaillierPublicKey { n, n_squared, g },
            private_key: PaillierPrivateKey { lambda, mu },
        }
    }

    /// Encodes the whole keypair, including the private key, with `crate::codec::encode`.
    ///
    /// # Panics
    /// Panics with `"serialization failed"` if the codec reports an error.
    pub fn to_bytes(&self) -> Vec<u8> {
        crate::codec::encode(self).expect("serialization failed")
    }

    /// Decodes a keypair previously produced by [`PaillierKeypair::to_bytes`].
    ///
    /// No consistency checks are performed on the decoded keys (e.g. that `n_squared == n * n`).
    ///
    /// # Errors
    /// Returns the codec's decode error, boxed.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
        Ok(crate::codec::decode(bytes)?)
    }
}
impl PaillierCiphertext {
    /// Homomorphic addition of two ciphertexts under the same `public_key`.
    ///
    /// Multiplies the residues modulo `n²`. Decrypting the result yields the sum of the two
    /// plaintexts modulo `n`. Nothing checks that both operands were produced under
    /// `public_key`.
    pub fn add(&self, other: &PaillierCiphertext, public_key: &PaillierPublicKey) -> Self {
        let product = &self.c * &other.c;
        Self {
            c: product % &public_key.n_squared,
        }
    }

    /// Homomorphic multiplication of the encrypted plaintext by `scalar`.
    ///
    /// Raises the residue to `scalar` modulo `n²`. Decrypting the result yields
    /// `plaintext * scalar` modulo `n`. With `scalar == 0` the result is the constant `1`, a
    /// deterministic (non-randomized) encryption of zero.
    pub fn mul_scalar(&self, scalar: u64, public_key: &PaillierPublicKey) -> Self {
        let exponent = BigUint::from(scalar);
        Self {
            c: self.c.modpow(&exponent, &public_key.n_squared),
        }
    }

    /// Encodes this ciphertext with `crate::codec::encode`.
    ///
    /// # Panics
    /// Panics with `"serialization failed"` if the codec reports an error.
    pub fn to_bytes(&self) -> Vec<u8> {
        let encoded = crate::codec::encode(self);
        encoded.expect("serialization failed")
    }

    /// Decodes a ciphertext previously produced by [`PaillierCiphertext::to_bytes`].
    ///
    /// The decoded value is not range-checked against any key.
    ///
    /// # Errors
    /// Returns the codec's decode error, boxed.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
        let decoded: Self = crate::codec::decode(bytes)?;
        Ok(decoded)
    }
}
pub fn encrypt(public_key: &PaillierPublicKey, message: u64) -> PaillierCiphertext {
let mut rng = rand_core06::OsRng;
let m = BigUint::from(message);
let r = loop {
let candidate = rng.gen_biguint_below(&public_key.n);
if gcd(&candidate, &public_key.n) == BigUint::one() {
break candidate;
}
};
let g_m = public_key.g.modpow(&m, &public_key.n_squared);
let r_n = r.modpow(&public_key.n, &public_key.n_squared);
let c = (g_m * r_n) % &public_key.n_squared;
PaillierCiphertext { c }
}
pub fn decrypt(keypair: &PaillierKeypair, ciphertext: &PaillierCiphertext) -> u64 {
let pk = &keypair.public_key;
let sk = &keypair.private_key;
let c_lambda = ciphertext.c.modpow(&sk.lambda, &pk.n_squared);
let l_value = l_function(&c_lambda, &pk.n);
let m = (l_value * &sk.mu) % &pk.n;
m.to_u64_digits().first().copied().unwrap_or(0)
}
/// Paillier's `L` function, `L(x) = (x - 1) / n`, using floor division.
///
/// # Panics
/// Panics if `x` is zero (the subtraction underflows) or if `n` is zero.
fn l_function(x: &BigUint, n: &BigUint) -> BigUint {
    let numerator = x - BigUint::one();
    numerator / n
}
/// Greatest common divisor via the iterative Euclidean algorithm. `gcd(0, 0)` is `0`.
fn gcd(a: &BigUint, b: &BigUint) -> BigUint {
    let (mut x, mut y) = (a.clone(), b.clone());
    while !y.is_zero() {
        let remainder = &x % &y;
        // Shift the pair: (x, y) <- (y, x mod y), moving y instead of cloning it.
        x = std::mem::replace(&mut y, remainder);
    }
    x
}
fn lcm(a: &BigUint, b: &BigUint) -> BigUint {
(a * b) / gcd(a, b)
}
fn mod_inverse(a: &BigUint, m: &BigUint) -> BigUint {
let (mut t, mut new_t) = (BigUint::zero(), BigUint::one());
let (mut r, mut new_r) = (m.clone(), a.clone());
while !new_r.is_zero() {
let quotient = &r / &new_r;
let temp_t = new_t.clone();
new_t = if t >= "ient * &new_t {
&t - "ient * &new_t
} else {
m - ("ient * &new_t - &t) % m
};
t = temp_t;
let temp_r = new_r.clone();
new_r = &r - "ient * &new_r;
r = temp_r;
}
if r > BigUint::one() {
panic!("a is not invertible");
}
t % m
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Key size used by every test: realistic arithmetic, still fast to generate.
    const TEST_KEY_BITS: usize = 512;

    /// Generates a fresh keypair of `TEST_KEY_BITS` bits.
    fn keypair() -> PaillierKeypair {
        PaillierKeypair::generate(TEST_KEY_BITS)
    }

    #[test]
    fn test_paillier_basic() {
        let keypair = keypair();
        let message = 42u64;
        let ciphertext = encrypt(&keypair.public_key, message);
        let decrypted = decrypt(&keypair, &ciphertext);
        assert_eq!(decrypted, message);
    }

    #[test]
    fn test_homomorphic_addition() {
        let keypair = keypair();
        let m1 = 100u64;
        let m2 = 50u64;
        let c1 = encrypt(&keypair.public_key, m1);
        let c2 = encrypt(&keypair.public_key, m2);
        let c_sum = c1.add(&c2, &keypair.public_key);
        let result = decrypt(&keypair, &c_sum);
        assert_eq!(result, m1 + m2);
    }

    #[test]
    fn test_homomorphic_scalar_multiplication() {
        let keypair = keypair();
        let m = 100u64;
        let k = 3u64;
        let c = encrypt(&keypair.public_key, m);
        let c_mul = c.mul_scalar(k, &keypair.public_key);
        let result = decrypt(&keypair, &c_mul);
        assert_eq!(result, m * k);
    }

    #[test]
    fn test_scalar_multiplication_by_zero() {
        let keypair = keypair();
        let c = encrypt(&keypair.public_key, 12_345);
        let c_zero = c.mul_scalar(0, &keypair.public_key);
        assert_eq!(decrypt(&keypair, &c_zero), 0);
    }

    #[test]
    fn test_multiple_additions() {
        let keypair = keypair();
        let values = [10u64, 20, 30, 40, 50];
        let expected_sum: u64 = values.iter().sum();
        let ciphertexts: Vec<_> = values
            .iter()
            .map(|&v| encrypt(&keypair.public_key, v))
            .collect();
        let mut c_sum = ciphertexts[0].clone();
        for c in &ciphertexts[1..] {
            c_sum = c_sum.add(c, &keypair.public_key);
        }
        let result = decrypt(&keypair, &c_sum);
        assert_eq!(result, expected_sum);
    }

    #[test]
    fn test_combined_operations() {
        let keypair = keypair();
        let c1 = encrypt(&keypair.public_key, 10);
        let c2 = encrypt(&keypair.public_key, 20);
        let c1_scaled = c1.mul_scalar(2, &keypair.public_key);
        let c2_scaled = c2.mul_scalar(3, &keypair.public_key);
        let c_result = c1_scaled.add(&c2_scaled, &keypair.public_key);
        let result = decrypt(&keypair, &c_result);
        assert_eq!(result, 2 * 10 + 3 * 20);
    }

    #[test]
    fn test_zero_encryption() {
        let keypair = keypair();
        let c = encrypt(&keypair.public_key, 0);
        let result = decrypt(&keypair, &c);
        assert_eq!(result, 0);
    }

    /// Two independent generations must not produce the same modulus: key generation is
    /// randomized, not deterministic.
    #[test]
    fn test_keypair_generation_is_randomized() {
        let kp1 = keypair();
        let kp2 = keypair();
        assert_ne!(kp1.public_key.n, kp2.public_key.n);
    }

    #[test]
    fn test_encryption_randomness() {
        let keypair = keypair();
        let message = 42u64;
        let c1 = encrypt(&keypair.public_key, message);
        let c2 = encrypt(&keypair.public_key, message);
        assert_ne!(c1.c, c2.c);
        assert_eq!(decrypt(&keypair, &c1), message);
        assert_eq!(decrypt(&keypair, &c2), message);
    }

    #[test]
    fn test_large_values() {
        let keypair = keypair();
        let m1 = 1_000_000u64;
        let m2 = 2_000_000u64;
        let c1 = encrypt(&keypair.public_key, m1);
        let c2 = encrypt(&keypair.public_key, m2);
        let c_sum = c1.add(&c2, &keypair.public_key);
        let result = decrypt(&keypair, &c_sum);
        assert_eq!(result, m1 + m2);
    }

    #[test]
    fn test_keypair_serialization() {
        let keypair = keypair();
        let bytes = keypair.to_bytes();
        let restored = PaillierKeypair::from_bytes(&bytes).unwrap();
        let message = 123u64;
        let c = encrypt(&restored.public_key, message);
        let result = decrypt(&restored, &c);
        assert_eq!(result, message);
    }

    #[test]
    fn test_ciphertext_serialization() {
        let keypair = keypair();
        let message = 456u64;
        let c = encrypt(&keypair.public_key, message);
        let bytes = c.to_bytes();
        let restored = PaillierCiphertext::from_bytes(&bytes).unwrap();
        assert_eq!(c, restored);
        let result = decrypt(&keypair, &restored);
        assert_eq!(result, message);
    }

    #[test]
    fn test_addition_commutativity() {
        let keypair = keypair();
        let m1 = 100u64;
        let m2 = 200u64;
        let c1 = encrypt(&keypair.public_key, m1);
        let c2 = encrypt(&keypair.public_key, m2);
        let sum1 = c1.add(&c2, &keypair.public_key);
        let sum2 = c2.add(&c1, &keypair.public_key);
        let result1 = decrypt(&keypair, &sum1);
        let result2 = decrypt(&keypair, &sum2);
        assert_eq!(result1, result2);
        assert_eq!(result1, m1 + m2);
    }

    #[test]
    fn test_addition_associativity() {
        let keypair = keypair();
        let m1 = 10u64;
        let m2 = 20u64;
        let m3 = 30u64;
        let c1 = encrypt(&keypair.public_key, m1);
        let c2 = encrypt(&keypair.public_key, m2);
        let c3 = encrypt(&keypair.public_key, m3);
        let sum1 = c1.add(&c2, &keypair.public_key);
        let sum1 = sum1.add(&c3, &keypair.public_key);
        let sum2 = c2.add(&c3, &keypair.public_key);
        let sum2 = c1.add(&sum2, &keypair.public_key);
        let result1 = decrypt(&keypair, &sum1);
        let result2 = decrypt(&keypair, &sum2);
        assert_eq!(result1, result2);
        assert_eq!(result1, m1 + m2 + m3);
    }

    #[test]
    fn test_scalar_distributivity() {
        let keypair = keypair();
        let m1 = 10u64;
        let m2 = 20u64;
        let k = 3u64;
        let c1 = encrypt(&keypair.public_key, m1);
        let c2 = encrypt(&keypair.public_key, m2);
        let sum = c1.add(&c2, &keypair.public_key);
        let scaled_sum = sum.mul_scalar(k, &keypair.public_key);
        let c1_scaled = c1.mul_scalar(k, &keypair.public_key);
        let c2_scaled = c2.mul_scalar(k, &keypair.public_key);
        let sum_scaled = c1_scaled.add(&c2_scaled, &keypair.public_key);
        let result1 = decrypt(&keypair, &scaled_sum);
        let result2 = decrypt(&keypair, &sum_scaled);
        assert_eq!(result1, result2);
        assert_eq!(result1, k * (m1 + m2));
    }

    #[test]
    fn test_bandwidth_aggregation_use_case() {
        let keypair = keypair();
        let peer1_bandwidth = 1024u64;
        let peer2_bandwidth = 2048u64;
        let peer3_bandwidth = 4096u64;
        let c1 = encrypt(&keypair.public_key, peer1_bandwidth);
        let c2 = encrypt(&keypair.public_key, peer2_bandwidth);
        let c3 = encrypt(&keypair.public_key, peer3_bandwidth);
        let c_total = c1.add(&c2, &keypair.public_key);
        let c_total = c_total.add(&c3, &keypair.public_key);
        let total_bandwidth = decrypt(&keypair, &c_total);
        assert_eq!(
            total_bandwidth,
            peer1_bandwidth + peer2_bandwidth + peer3_bandwidth
        );
        assert_eq!(total_bandwidth, 7168);
    }

    #[test]
    fn test_gcd_and_lcm() {
        let big = |v: u64| BigUint::from(v);
        assert_eq!(gcd(&big(12), &big(18)), big(6));
        assert_eq!(gcd(&big(17), &big(5)), big(1));
        assert_eq!(gcd(&big(0), &big(7)), big(7));
        assert_eq!(lcm(&big(4), &big(6)), big(12));
        assert_eq!(lcm(&big(7), &big(3)), big(21));
    }

    #[test]
    fn test_mod_inverse() {
        let big = |v: u64| BigUint::from(v);
        assert_eq!(mod_inverse(&big(3), &big(7)), big(5));
        assert_eq!(mod_inverse(&big(10), &big(17)), big(12));
        assert_eq!(mod_inverse(&big(1), &big(13)), big(1));
    }

    #[test]
    #[should_panic(expected = "a is not invertible")]
    fn test_mod_inverse_rejects_non_units() {
        // gcd(6, 9) = 3, so 6 has no inverse modulo 9.
        mod_inverse(&BigUint::from(6u64), &BigUint::from(9u64));
    }

    #[test]
    fn test_l_function() {
        // L(1 + k*n) = k.
        let n = BigUint::from(11u64);
        let x = BigUint::from(1u64 + 4 * 11);
        assert_eq!(l_function(&x, &n), BigUint::from(4u64));
    }
}