use chacha20poly1305::{
aead::{Aead, Error as AError, NewAead},
ChaCha20Poly1305, Key,
};
pub use chacha20poly1305::aead::Error as EciesError;
use hkdf::Hkdf;
use rand_core::RngCore;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use crate::group::{Curve, Element};
/// Length in bytes of the ChaCha20-Poly1305 nonce stored in each [`EciesCipher`].
const NONCE_LEN: usize = 12;
/// Length in bytes of the symmetric key derived from the Diffie-Hellman point.
const KEY_LEN: usize = 32;
/// HKDF `info` string used when expanding the shared point into the AEAD key.
/// Changing it changes every derived key, so previously produced ciphertexts
/// would no longer decrypt.
const DOMAIN: [u8; 4] = [1, 9, 6, 9];
/// An ECIES ciphertext: everything a recipient needs, besides their secret
/// scalar, to recover the plaintext with [`decrypt`].
///
/// NOTE: do not reorder the fields. Serde formats that encode structs
/// positionally (such as `bincode`) depend on declaration order, so reordering
/// would break compatibility with already-serialized ciphertexts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EciesCipher<C: Curve> {
    /// Authenticated ciphertext produced by ChaCha20-Poly1305.
    aead: Vec<u8>,
    /// Sender's ephemeral public point: `C::Point::one()` multiplied by the
    /// ephemeral secret scalar sampled in [`encrypt`].
    ephemeral: C::Point,
    /// Random nonce used for the AEAD encryption.
    nonce: [u8; NONCE_LEN],
}
/// Encrypts `msg` so that only the holder of the secret scalar behind `to`
/// can decrypt it with [`decrypt`].
///
/// A fresh ephemeral scalar `r` is sampled from `rng`. The AEAD key is derived
/// from the shared point `r · to`, and `C::Point::one()` scaled by `r` is
/// shipped alongside the ciphertext. After the scalar, 12 more bytes are drawn
/// from `rng` for the nonce.
///
/// # Panics
///
/// Panics if ChaCha20-Poly1305 encryption returns an error, which this code
/// treats as unreachable.
pub fn encrypt<C: Curve, R: RngCore>(to: &C::Point, msg: &[u8], rng: &mut R) -> EciesCipher<C> {
    // The scalar must be drawn before the nonce bytes so RNG consumption order
    // stays the same for seeded generators.
    let r = C::Scalar::rand(rng);

    // Shared secret r·to, and the public half r·one() sent to the recipient.
    let mut shared = to.clone();
    shared.mul(&r);
    let mut ephemeral = C::Point::one();
    ephemeral.mul(&r);

    let key = derive::<C>(&shared);
    let cipher = ChaCha20Poly1305::new(Key::from_slice(&key[..]));

    let mut nonce = [0u8; NONCE_LEN];
    rng.fill_bytes(&mut nonce);

    let ciphertext = cipher
        .encrypt(&nonce.into(), msg)
        .expect("aead should not fail");

    EciesCipher {
        aead: ciphertext,
        ephemeral,
        nonce,
    }
}
pub fn decrypt<C: Curve>(private: &C::Scalar, cipher: &EciesCipher<C>) -> Result<Vec<u8>, AError> {
let mut dh = cipher.ephemeral.clone();
dh.mul(private);
let ephemeral_key = derive::<C>(&dh);
let aead = ChaCha20Poly1305::new(Key::from_slice(ephemeral_key.as_slice()));
aead.decrypt(&cipher.nonce.into(), &cipher.aead[..])
}
/// Derives the symmetric AEAD key from a Diffie-Hellman shared point.
///
/// The point is serialized with `bincode` and used as HKDF-SHA256 input key
/// material with no salt. The result is expanded into `KEY_LEN` bytes using
/// [`DOMAIN`] as the HKDF `info` string.
///
/// # Panics
///
/// Panics if `bincode` cannot serialize the point. HKDF expansion itself
/// cannot fail here, because `KEY_LEN` bytes is far below HKDF-SHA256's
/// output cap of 255 * 32 bytes.
fn derive<C: Curve>(dh: &C::Point) -> [u8; KEY_LEN] {
    let ikm = bincode::serialize(dh).expect("could not serialize element");
    let hkdf = Hkdf::<Sha256>::new(None, &ikm);
    // The array type already fixes the key length, so no runtime length check
    // is needed afterwards.
    let mut key = [0u8; KEY_LEN];
    hkdf.expand(&DOMAIN, &mut key)
        .expect("hkdf should not fail");
    key
}
#[cfg(feature = "bls12_381")]
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use crate::curve::bls12381::{G1Curve as Curve, Scalar, G1};

    use super::*;

    /// Generates a fresh keypair `(secret, G1::one() * secret)`.
    fn kp() -> (Scalar, G1) {
        let secret = Scalar::rand(&mut thread_rng());
        let mut public = G1::one();
        public.mul(&secret);
        (secret, public)
    }

    #[test]
    fn test_decryption() {
        let (s1, _) = kp();
        let (s2, p2) = kp();
        let data = vec![1, 2, 3, 4];
        let mut cipher = encrypt::<Curve, _>(&p2, &data, &mut thread_rng());
        let deciphered = decrypt::<Curve>(&s2, &cipher).unwrap();
        assert_eq!(data, deciphered);
        decrypt::<Curve>(&s1, &cipher).unwrap_err();
        cipher.aead = vec![0; 32];
        decrypt::<Curve>(&s2, &cipher).unwrap_err();
    }

    #[test]
    fn test_empty_message_roundtrip() {
        let (s, p) = kp();
        let cipher = encrypt::<Curve, _>(&p, &[], &mut thread_rng());
        let deciphered = decrypt::<Curve>(&s, &cipher).unwrap();
        assert!(deciphered.is_empty());
    }

    #[test]
    fn test_tampered_nonce_and_ephemeral_rejected() {
        let (s, p) = kp();
        let (_, other) = kp();
        let data: Vec<u8> = vec![5, 6, 7, 8];
        let mut cipher = encrypt::<Curve, _>(&p, &data, &mut thread_rng());

        // Flipping a single nonce bit must break authentication. The bit is
        // flipped in place and then restored, rather than cloning the cipher,
        // because the derived `Clone` impl needs `Curve: Clone`.
        cipher.nonce[0] ^= 1;
        decrypt::<Curve>(&s, &cipher).unwrap_err();
        cipher.nonce[0] ^= 1;
        assert_eq!(decrypt::<Curve>(&s, &cipher).unwrap(), data);

        // A different ephemeral point yields a different derived key.
        cipher.ephemeral = other;
        decrypt::<Curve>(&s, &cipher).unwrap_err();
    }
}