use blake3::Hasher;
use chacha20poly1305::{
ChaCha20Poly1305, Nonce,
aead::{Aead, KeyInit},
};
use curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_TABLE,
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
};
use rand::RngExt;
use serde::{Deserialize, Serialize};
/// Errors reported by the proxy re-encryption API.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProxyReError {
    /// Ciphertext has an invalid format (not produced within this module).
    InvalidCiphertext,
    /// AEAD decryption rejected the ciphertext (wrong key or tampered data).
    DecryptionFailed,
    /// AEAD encryption failed.
    EncryptionFailed,
    /// Bytes do not decode to a valid Ristretto point.
    InvalidPublicKey,
    /// Re-encryption key is invalid (not produced within this module).
    InvalidReKey,
    /// The crate codec failed to encode or decode a value.
    SerializationError,
}

impl std::fmt::Display for ProxyReError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InvalidCiphertext => "Invalid ciphertext format",
            Self::DecryptionFailed => "Decryption failed",
            Self::EncryptionFailed => "Encryption failed",
            Self::InvalidPublicKey => "Invalid public key",
            Self::InvalidReKey => "Invalid re-encryption key",
            Self::SerializationError => "Serialization/deserialization error",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ProxyReError {}

/// Shorthand for results whose error type is [`ProxyReError`].
pub type ProxyReResult<T> = Result<T, ProxyReError>;
/// Secret key: a scalar modulo the Ristretto group order ℓ.
///
/// Does not implement `Debug`, so the scalar cannot end up in `{:?}` output.
#[derive(Clone, Serialize, Deserialize)]
pub struct ProxyReSecretKey(Scalar);

impl ProxyReSecretKey {
    /// Generates a fresh secret key from the thread-local RNG.
    ///
    /// Samples 64 random bytes and reduces them modulo ℓ. This is the same
    /// construction `curve25519_dalek::Scalar::random` uses. Reducing 512 bits puts the
    /// result within ~2^-260 of uniform. A 256-bit reduction (ℓ ≈ 2^252) leaves a
    /// ~2^-128 bias.
    pub fn generate() -> Self {
        let mut rng = rand::rng();
        let mut wide = [0u8; 64];
        rng.fill(&mut wide);
        Self(Scalar::from_bytes_mod_order_wide(&wide))
    }

    /// Returns the canonical 32-byte little-endian encoding of the scalar.
    pub fn to_bytes(&self) -> [u8; 32] {
        self.0.to_bytes()
    }

    /// Builds a secret key from 32 bytes, reducing them modulo ℓ.
    ///
    /// This never fails. A non-canonical input is silently mapped to its residue, so
    /// `from_bytes(b).to_bytes() == b` holds only when `b` is already canonical.
    pub fn from_bytes(bytes: &[u8; 32]) -> Self {
        Self(Scalar::from_bytes_mod_order(*bytes))
    }
}
/// Public key: the Ristretto point obtained by multiplying the basepoint by the
/// matching secret scalar.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProxyRePublicKey(RistrettoPoint);

impl ProxyRePublicKey {
    /// Derives the public key for `secret` via the precomputed basepoint table.
    pub fn from_secret(secret: &ProxyReSecretKey) -> Self {
        let point = &secret.0 * RISTRETTO_BASEPOINT_TABLE;
        Self(point)
    }

    /// Returns the 32-byte compressed Ristretto encoding.
    pub fn to_bytes(&self) -> [u8; 32] {
        let compressed = self.0.compress();
        compressed.to_bytes()
    }

    /// Parses a compressed Ristretto encoding.
    ///
    /// # Errors
    /// [`ProxyReError::InvalidPublicKey`] if the bytes do not decompress to a group element.
    pub fn from_bytes(bytes: &[u8; 32]) -> ProxyReResult<Self> {
        match CompressedRistretto(*bytes).decompress() {
            Some(point) => Ok(Self(point)),
            None => Err(ProxyReError::InvalidPublicKey),
        }
    }
}
/// A secret key bundled with its derived public key.
///
/// The methods forward to the module-level [`encrypt`], [`decrypt`] and
/// [`generate_re_key`] functions.
#[derive(Clone)]
pub struct ProxyReKeypair {
    secret: ProxyReSecretKey,
    public: ProxyRePublicKey,
}

impl ProxyReKeypair {
    /// Draws a fresh secret key and derives the matching public key.
    pub fn generate() -> Self {
        let secret = ProxyReSecretKey::generate();
        Self {
            public: ProxyRePublicKey::from_secret(&secret),
            secret,
        }
    }

    /// Returns a copy of the public half (the type is `Copy`).
    pub fn public_key(&self) -> ProxyRePublicKey {
        self.public
    }

    /// Borrows the secret half.
    pub fn secret_key(&self) -> &ProxyReSecretKey {
        &self.secret
    }

    /// Encrypts `plaintext` to this keypair's own public key.
    pub fn encrypt(&self, plaintext: &[u8]) -> ProxyReResult<ProxyReCiphertext> {
        let own_pk = &self.public;
        encrypt(own_pk, plaintext)
    }

    /// Decrypts `ciphertext` with this keypair's secret key.
    pub fn decrypt(&self, ciphertext: &ProxyReCiphertext) -> ProxyReResult<Vec<u8>> {
        let own_sk = &self.secret;
        decrypt(own_sk, ciphertext)
    }

    /// Produces a re-encryption key from this keypair toward `target_pk`.
    pub fn generate_re_key(&self, target_pk: &ProxyRePublicKey) -> ProxyReReKey {
        let own_sk = &self.secret;
        generate_re_key(own_sk, target_pk)
    }
}
/// Delegation token produced by `generate_re_key` and passed to `re_encrypt`.
///
/// Derives `Serialize`/`Deserialize`, so it can be encoded and handed to another party.
/// NOTE(review): field order presumably matters for the crate codec's encoding — keep
/// it stable unless `crate::codec` is confirmed to be self-describing.
#[derive(Clone, Serialize, Deserialize)]
pub struct ProxyReReKey {
// Scalar computed by `generate_re_key` from the delegator's secret key.
// NOTE(review): `re_encrypt` in this file does not read this field.
re_key: Scalar,
// Delegatee's public key; `re_encrypt` encrypts the wrapped ciphertext to it.
target_pk: ProxyRePublicKey,
}
/// Hybrid ciphertext produced by [`encrypt`] and [`re_encrypt`].
///
/// Derives `Debug` so results holding it can be printed, `unwrap_err`'d and compared
/// in callers and tests. None of its fields is secret material.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProxyReCiphertext {
    /// Ephemeral public key. The recipient multiplies it by their secret scalar to
    /// rebuild the shared point that keys the AEAD.
    ephemeral_pk: RistrettoPoint,
    /// 32-byte marker: `encrypt` writes zero bytes, `re_encrypt` writes `0x01` bytes.
    encrypted_key: Vec<u8>,
    /// ChaCha20-Poly1305 output (encrypted payload followed by the authentication tag).
    ciphertext: Vec<u8>,
    /// 96-bit AEAD nonce, drawn at random for each encryption.
    nonce: [u8; 12],
}

impl ProxyReCiphertext {
    /// Serializes the ciphertext with the crate codec.
    ///
    /// # Errors
    /// [`ProxyReError::SerializationError`] if encoding fails.
    pub fn to_bytes(&self) -> ProxyReResult<Vec<u8>> {
        crate::codec::encode(self).map_err(|_| ProxyReError::SerializationError)
    }

    /// Parses bytes produced by [`ProxyReCiphertext::to_bytes`].
    ///
    /// # Errors
    /// [`ProxyReError::SerializationError`] if decoding fails.
    pub fn from_bytes(bytes: &[u8]) -> ProxyReResult<Self> {
        crate::codec::decode(bytes).map_err(|_| ProxyReError::SerializationError)
    }
}
/// Encrypts `plaintext` to the holder of `pk` using ephemeral Diffie-Hellman plus an AEAD.
///
/// Steps:
/// 1. Draw a fresh ephemeral secret `e`.
/// 2. Hash the shared point `e·pk` into a ChaCha20-Poly1305 key.
/// 3. Seal the plaintext under a random 12-byte nonce.
///
/// The ephemeral public key travels with the ciphertext so the recipient can recompute
/// the shared point. `encrypted_key` is set to 32 zero bytes.
///
/// # Errors
/// [`ProxyReError::EncryptionFailed`] if the AEAD rejects the input.
pub fn encrypt(pk: &ProxyRePublicKey, plaintext: &[u8]) -> ProxyReResult<ProxyReCiphertext> {
    let mut rng = rand::rng();
    let eph_secret = ProxyReSecretKey::generate();
    let eph_public = ProxyRePublicKey::from_secret(&eph_secret);
    let key = derive_symmetric_key(&(pk.0 * eph_secret.0));

    let mut nonce = [0u8; 12];
    rng.fill(&mut nonce);

    let sealed = ChaCha20Poly1305::new(&key.into())
        .encrypt(Nonce::from_slice(&nonce), plaintext)
        .map_err(|_| ProxyReError::EncryptionFailed)?;

    Ok(ProxyReCiphertext {
        ephemeral_pk: eph_public.0,
        encrypted_key: vec![0u8; 32],
        ciphertext: sealed,
        nonce,
    })
}
pub fn decrypt(sk: &ProxyReSecretKey, ciphertext: &ProxyReCiphertext) -> ProxyReResult<Vec<u8>> {
let shared_point = ciphertext.ephemeral_pk * sk.0;
let sym_key = derive_symmetric_key(&shared_point);
let cipher = ChaCha20Poly1305::new(&sym_key.into());
let nonce = Nonce::from_slice(&ciphertext.nonce);
cipher
.decrypt(nonce, ciphertext.ciphertext.as_ref())
.map_err(|_| ProxyReError::DecryptionFailed)
}
pub fn generate_re_key(
delegator_sk: &ProxyReSecretKey,
delegatee_pk: &ProxyRePublicKey,
) -> ProxyReReKey {
let re_key = delegator_sk.0.invert();
ProxyReReKey {
re_key,
target_pk: *delegatee_pk,
}
}
pub fn re_encrypt(
ciphertext: &ProxyReCiphertext,
re_key: &ProxyReReKey,
) -> ProxyReResult<ProxyReCiphertext> {
let mut rng = rand::rng();
let re_ephemeral_sk = ProxyReSecretKey::generate();
let re_ephemeral_pk = ProxyRePublicKey::from_secret(&re_ephemeral_sk);
let new_shared_point = re_key.target_pk.0 * re_ephemeral_sk.0;
let new_sym_key = derive_symmetric_key(&new_shared_point);
let mut nonce_bytes = [0u8; 12];
rng.fill(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let cipher = ChaCha20Poly1305::new(&new_sym_key.into());
let original_serialized =
crate::codec::encode(ciphertext).map_err(|_| ProxyReError::SerializationError)?;
let new_ciphertext = cipher
.encrypt(nonce, original_serialized.as_ref())
.map_err(|_| ProxyReError::EncryptionFailed)?;
Ok(ProxyReCiphertext {
ephemeral_pk: re_ephemeral_pk.0,
encrypted_key: vec![1u8; 32], ciphertext: new_ciphertext,
nonce: nonce_bytes,
})
}
/// Hashes a shared curve point into a 256-bit AEAD key.
///
/// Computes BLAKE3 over the domain-separation tag `chie-proxy-re-v1` followed by the
/// point's 32-byte compressed encoding. Changing either input invalidates every
/// existing ciphertext.
fn derive_symmetric_key(point: &RistrettoPoint) -> [u8; 32] {
    const DOMAIN: &[u8] = b"chie-proxy-re-v1";
    let encoded = point.compress();
    let digest = Hasher::new()
        .update(DOMAIN)
        .update(encoded.as_bytes())
        .finalize();
    *digest.as_bytes()
}
// Unit tests for key handling, hybrid encryption, and the layered `re_encrypt` wrapper.
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a `re_encrypt` output with the delegatee's keys and decodes the ciphertext
    /// it wraps. The wrapped ciphertext is still addressed to the original recipient.
    fn open_wrapped(delegatee: &ProxyReKeypair, wrapped: &ProxyReCiphertext) -> ProxyReCiphertext {
        let outer = delegatee.decrypt(wrapped).unwrap();
        crate::codec::decode(&outer).unwrap()
    }

    #[test]
    fn test_keypair_generation() {
        let keypair = ProxyReKeypair::generate();
        let pk_derived = ProxyRePublicKey::from_secret(keypair.secret_key());
        assert_eq!(pk_derived, keypair.public_key());
    }

    #[test]
    fn test_basic_encryption_decryption() {
        let keypair = ProxyReKeypair::generate();
        let plaintext = b"Hello, proxy re-encryption!";
        let ciphertext = keypair.encrypt(plaintext).unwrap();
        let decrypted = keypair.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encryption_produces_different_ciphertexts() {
        let keypair = ProxyReKeypair::generate();
        let plaintext = b"Test message";
        let ct1 = keypair.encrypt(plaintext).unwrap();
        let ct2 = keypair.encrypt(plaintext).unwrap();
        assert_ne!(ct1.ephemeral_pk.compress(), ct2.ephemeral_pk.compress());
        assert_ne!(ct1.ciphertext, ct2.ciphertext);
    }

    #[test]
    fn test_wrong_key_decryption_fails() {
        let alice = ProxyReKeypair::generate();
        let bob = ProxyReKeypair::generate();
        let plaintext = b"Secret message";
        let ciphertext = alice.encrypt(plaintext).unwrap();
        assert!(bob.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let keypair = ProxyReKeypair::generate();
        let mut ciphertext = keypair.encrypt(b"Integrity matters").unwrap();
        ciphertext.ciphertext[0] ^= 0x01;
        assert_eq!(
            keypair.decrypt(&ciphertext),
            Err(ProxyReError::DecryptionFailed)
        );
    }

    #[test]
    fn test_proxy_re_encryption() {
        let alice = ProxyReKeypair::generate();
        let bob = ProxyReKeypair::generate();
        let plaintext = b"Delegated content";
        let ciphertext = alice.encrypt(plaintext).unwrap();
        let alice_decrypted = alice.decrypt(&ciphertext).unwrap();
        assert_eq!(alice_decrypted, plaintext);
        let re_key = alice.generate_re_key(&bob.public_key());
        let re_encrypted = re_encrypt(&ciphertext, &re_key).unwrap();
        let inner_ciphertext = open_wrapped(&bob, &re_encrypted);
        let final_plaintext = alice.decrypt(&inner_ciphertext).unwrap();
        assert_eq!(final_plaintext, plaintext);
    }

    #[test]
    fn test_re_encrypted_not_openable_by_delegator() {
        let alice = ProxyReKeypair::generate();
        let bob = ProxyReKeypair::generate();
        let ciphertext = alice.encrypt(b"For Bob's layer only").unwrap();
        let re_key = alice.generate_re_key(&bob.public_key());
        let wrapped = re_encrypt(&ciphertext, &re_key).unwrap();
        assert!(alice.decrypt(&wrapped).is_err());
    }

    #[test]
    fn test_public_key_serialization() {
        let keypair = ProxyReKeypair::generate();
        let pk = keypair.public_key();
        let bytes = pk.to_bytes();
        let pk_restored = ProxyRePublicKey::from_bytes(&bytes).unwrap();
        assert_eq!(pk, pk_restored);
    }

    #[test]
    fn test_secret_key_serialization() {
        let keypair = ProxyReKeypair::generate();
        let sk = keypair.secret_key();
        let bytes = sk.to_bytes();
        let sk_restored = ProxyReSecretKey::from_bytes(&bytes);
        let pk1 = ProxyRePublicKey::from_secret(sk);
        let pk2 = ProxyRePublicKey::from_secret(&sk_restored);
        assert_eq!(pk1, pk2);
    }

    #[test]
    fn test_invalid_public_key() {
        let invalid_bytes = [255u8; 32];
        assert!(ProxyRePublicKey::from_bytes(&invalid_bytes).is_err());
    }

    #[test]
    fn test_ciphertext_serialization() {
        let keypair = ProxyReKeypair::generate();
        let plaintext = b"Serialize this";
        let ciphertext = keypair.encrypt(plaintext).unwrap();
        let serialized = crate::codec::encode(&ciphertext).unwrap();
        let deserialized: ProxyReCiphertext = crate::codec::decode(&serialized).unwrap();
        let decrypted = keypair.decrypt(&deserialized).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_ciphertext_bytes_roundtrip() {
        let keypair = ProxyReKeypair::generate();
        let plaintext = b"Round trip";
        let ciphertext = keypair.encrypt(plaintext).unwrap();
        let bytes = ciphertext.to_bytes().unwrap();
        let restored = ProxyReCiphertext::from_bytes(&bytes).unwrap();
        assert_eq!(keypair.decrypt(&restored).unwrap(), plaintext);
    }

    #[test]
    fn test_ciphertext_from_empty_bytes_fails() {
        assert!(matches!(
            ProxyReCiphertext::from_bytes(&[]),
            Err(ProxyReError::SerializationError)
        ));
    }

    #[test]
    fn test_empty_plaintext() {
        let keypair = ProxyReKeypair::generate();
        let plaintext = b"";
        let ciphertext = keypair.encrypt(plaintext).unwrap();
        let decrypted = keypair.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_large_plaintext() {
        let keypair = ProxyReKeypair::generate();
        let plaintext = vec![42u8; 10_000];
        let ciphertext = keypair.encrypt(&plaintext).unwrap();
        let decrypted = keypair.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_multiple_delegations() {
        let alice = ProxyReKeypair::generate();
        let bob = ProxyReKeypair::generate();
        let carol = ProxyReKeypair::generate();
        let plaintext = b"Multi-hop delegation";
        let ct_alice = alice.encrypt(plaintext).unwrap();
        let re_key_alice_to_bob = alice.generate_re_key(&bob.public_key());
        let ct_bob = re_encrypt(&ct_alice, &re_key_alice_to_bob).unwrap();
        let re_key_alice_to_carol = alice.generate_re_key(&carol.public_key());
        let ct_carol = re_encrypt(&ct_alice, &re_key_alice_to_carol).unwrap();

        // Each delegatee unwraps a ciphertext that still opens to the original plaintext.
        let from_bob = open_wrapped(&bob, &ct_bob);
        let from_carol = open_wrapped(&carol, &ct_carol);
        assert_eq!(alice.decrypt(&from_bob).unwrap(), plaintext);
        assert_eq!(alice.decrypt(&from_carol).unwrap(), plaintext);

        // Each wrap is addressed to exactly one delegatee.
        assert!(bob.decrypt(&ct_carol).is_err());
        assert!(carol.decrypt(&ct_bob).is_err());
    }

    #[test]
    fn test_re_key_serialization() {
        let alice = ProxyReKeypair::generate();
        let bob = ProxyReKeypair::generate();
        let re_key = alice.generate_re_key(&bob.public_key());
        let serialized = crate::codec::encode(&re_key).unwrap();
        let deserialized: ProxyReReKey = crate::codec::decode(&serialized).unwrap();
        assert_eq!(re_key.target_pk, deserialized.target_pk);
        assert_eq!(re_key.re_key, deserialized.re_key);

        // A deserialized re-key must still drive re-encryption toward the same target.
        let ciphertext = alice.encrypt(b"via restored key").unwrap();
        let wrapped = re_encrypt(&ciphertext, &deserialized).unwrap();
        assert!(bob.decrypt(&wrapped).is_ok());
    }

    #[test]
    fn test_error_display() {
        assert_eq!(ProxyReError::DecryptionFailed.to_string(), "Decryption failed");
        assert_eq!(ProxyReError::InvalidPublicKey.to_string(), "Invalid public key");
    }
}