use crate::curve::{EmbeddedFr, EmbeddedGroupAffine, embedded};
use crate::curve::{FR_BYTES, Fr};
use crate::hash::transient_hash;
use crate::repr::{FieldRepr, FromFieldRepr};
use k256::elliptic_curve::subtle::CtOption;
#[cfg(feature = "proptest")]
use proptest_derive::Arbitrary;
use rand::distributions::Standard;
use rand::prelude::Distribution;
use rand::{CryptoRng, Rng};
use serde::{
de::{Deserialize, Deserializer},
ser::{Error, Serialize, Serializer},
};
#[cfg(feature = "proptest")]
use serialize::randomised_serialization_test;
use serialize::{Deserializable, Serializable, Tagged, tag_enforcement_test};
use std::fmt::{self, Debug, Formatter};
use std::iter::once;
use zeroize::Zeroize;
/// Public key for this module's encryption scheme: a point on the embedded
/// curve. `SecretKey::public_key` derives it as the curve generator multiplied
/// by the secret scalar; `PublicKey::encrypt` uses it to produce a `Ciphertext`.
#[derive(Copy, Clone, Debug, Eq, Serializable)]
#[tag = "encryption-public-key[v1]"]
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
pub struct PublicKey(EmbeddedGroupAffine);
impl Distribution<PublicKey> for Standard {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> PublicKey {
PublicKey(rng.r#gen())
}
}
// Test-generating macros imported from the `serialize` crate. Their behaviour
// is assumed from their names; confirm against their definitions:
// - a randomised serialization round-trip test, built only with the `proptest`
//   feature;
// - a check on the `encryption-public-key[v1]` tag declared on `PublicKey`.
#[cfg(feature = "proptest")]
randomised_serialization_test!(PublicKey);
tag_enforcement_test!(PublicKey);
impl PartialEq for PublicKey {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl Serialize for PublicKey {
    /// Serializes the key for serde as an opaque byte string that holds its
    /// binary [`Serializable`] encoding.
    ///
    /// # Errors
    /// Returns `S::Error` if the binary encoding fails or the serializer
    /// rejects the byte string.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // The encoded length is known up front, so size the buffer once
        // instead of letting it grow while the point is written.
        let mut vec = Vec::with_capacity(<PublicKey as Serializable>::serialized_size(self));
        <PublicKey as Serializable>::serialize(self, &mut vec).map_err(S::Error::custom)?;
        serializer.serialize_bytes(&vec)
    }
}
impl<'de> Deserialize<'de> for PublicKey {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let bytes = serde_bytes::ByteBuf::deserialize(deserializer)?;
<PublicKey as Deserializable>::deserialize(&mut &bytes[..], 0)
.map_err(serde::de::Error::custom)
}
}
/// Secret key for this module's encryption scheme: a scalar of the embedded
/// curve (`EmbeddedFr`). Its `Debug` output is redacted.
///
/// NOTE(review): `Zeroize` is derived, but the type is `Copy` and therefore
/// cannot implement `Drop`. The key is wiped only when `zeroize()` is called
/// explicitly, and implicit copies are never wiped.
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
#[derive(Copy, Clone, Serializable, Zeroize)]
#[tag = "encryption-secret-key[v1]"]
pub struct SecretKey(EmbeddedFr);
impl Debug for SecretKey {
    /// Prints a fixed placeholder so key material never reaches logs or panic
    /// messages.
    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
        formatter.write_str("<encryption secret key>")
    }
}
// Test-generating macro from the `serialize` crate. Presumably it pins the
// `encryption-secret-key[v1]` tag declared on `SecretKey`; confirm against its
// definition.
tag_enforcement_test!(SecretKey);
/// Ciphertext produced by `PublicKey::encrypt` and consumed by
/// `SecretKey::decrypt`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Ciphertext {
/// Challenge point: the curve generator multiplied by the sender's random
/// scalar. The `Deserializable` impl rejects the identity here.
pub c: EmbeddedGroupAffine,
/// Padded field elements. The first is a zero marker and the rest are the
/// message's field representation, each offset by a hash-derived pad.
pub ciph: Vec<Fr>,
}
impl Serializable for Ciphertext {
fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
Serializable::serialize(&self.c, writer)?;
Serializable::serialize(&self.ciph, writer)
}
fn serialized_size(&self) -> usize {
self.c.serialized_size() + self.ciph.serialized_size()
}
}
impl Deserializable for Ciphertext {
    /// Reads a ciphertext as the challenge point followed by the vector of padded
    /// field elements, mirroring the `Serializable` layout.
    ///
    /// # Errors
    /// Propagates failures from the recursion-depth check and from decoding
    /// either field. Returns `InvalidData` when the decoded challenge is the
    /// identity element. Both fields are read before that check runs.
    fn deserialize(
        reader: &mut impl std::io::Read,
        mut recursion_depth: u32,
    ) -> std::io::Result<Self> {
        Self::check_rec(&mut recursion_depth)?;
        let challenge = EmbeddedGroupAffine::deserialize(reader, recursion_depth)?;
        let padded = <Vec<Fr> as Deserializable>::deserialize(reader, recursion_depth)?;
        if challenge.is_identity() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "ciphertext challenge may not be the identity element",
            ));
        }
        Ok(Ciphertext {
            c: challenge,
            ciph: padded,
        })
    }
}
impl PublicKey {
    /// Encrypts `msg` to this public key.
    ///
    /// Encryption proceeds as follows:
    /// 1. Draw a fresh scalar `y` from `rng`. The ciphertext's challenge is
    ///    `generator * y`.
    /// 2. Hash the affine coordinates of the shared point `self * y` into a pad
    ///    key `k`.
    /// 3. Offset element `i` of the stream `[0, msg.field_vec()...]` by
    ///    `transient_hash([k, i])`.
    ///
    /// The secret-key holder recomputes the shared point as `challenge * sk`.
    ///
    /// NOTE(review): if this key is the identity point, the shared point is at
    /// infinity and `k` comes from the fixed coordinates `(0, 0)`, so the pads
    /// are publicly computable. Confirm callers never hold such a key.
    pub fn encrypt<R: Rng + CryptoRng + ?Sized, T: FieldRepr>(
        &self,
        rng: &mut R,
        msg: &T,
    ) -> Ciphertext {
        let ephemeral: EmbeddedFr = rng.r#gen();
        let shared = self.0 * ephemeral;
        // The point at infinity has no affine coordinates; map it to (0, 0).
        let (shared_x, shared_y) = if shared.is_infinity() {
            (0.into(), 0.into())
        } else {
            (shared.x().unwrap(), shared.y().unwrap())
        };
        let pad_key = transient_hash(&[shared_x, shared_y]);
        // A leading zero marker lets the decryptor detect a wrong key or a
        // corrupted ciphertext.
        let padded = once(0.into())
            .chain(msg.field_vec())
            .zip(0u64..)
            .map(|(elem, index)| transient_hash(&[pad_key, index.into()]) + elem)
            .collect();
        Ciphertext {
            c: EmbeddedGroupAffine::generator() * ephemeral,
            ciph: padded,
        }
    }
}
impl SecretKey {
pub const BYTES: usize = FR_BYTES;
pub fn new<R: Rng + CryptoRng + ?Sized>(rng: &mut R) -> Self {
SecretKey(rng.r#gen())
}
pub fn from_uniform_bytes(bytes: &[u8; 64]) -> Self {
let value = embedded::Scalar::from_bytes_wide(bytes);
SecretKey(EmbeddedFr(value))
}
pub fn from_repr(bytes: &[u8; Self::BYTES]) -> CtOption<Self> {
let val = embedded::Scalar::from_bytes(bytes);
val.map(|scalar| SecretKey(EmbeddedFr(scalar)))
}
pub fn repr(&self) -> [u8; Self::BYTES] {
self.0.0.to_bytes()
}
pub fn public_key(&self) -> PublicKey {
PublicKey(EmbeddedGroupAffine::generator() * self.0)
}
pub fn decrypt<T: FromFieldRepr>(&self, ciph: &Ciphertext) -> Option<T> {
if ciph.c.is_identity() {
return None;
}
let k_star = ciph.c * self.0;
let coords = if k_star.is_infinity() {
(0.into(), 0.into())
} else {
(k_star.x().unwrap(), k_star.y().unwrap())
};
let k = transient_hash(&[coords.0, coords.1]);
let plain = ciph
.ciph
.iter()
.enumerate()
.map(|(ctr, ciph)| *ciph - transient_hash(&[k, (ctr as u64).into()]))
.collect::<Vec<_>>();
if plain.is_empty() || plain[0] != 0.into() {
debug!("zero element check in decryption failed");
return None;
}
T::from_field_repr(&plain[1..])
}
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;
    // Imported unconditionally: the deterministic round-trip test below needs a
    // seeded RNG even when the `proptest` feature is off.
    use rand::{SeedableRng, rngs::StdRng};
    use super::*;

    #[cfg(feature = "proptest")]
    proptest! {
        // Decrypting with the matching key recovers any 32-byte message.
        #[test]
        fn correctness(
            key in <SecretKey as Arbitrary>::arbitrary(),
            msg in proptest::array::uniform32(proptest::num::u8::ANY)
        ) {
            let mut rng = StdRng::from_seed([0x42; 32]);
            let ciph = key.public_key().encrypt(&mut rng, &msg);
            let dec: Option<[u8; 32]> = key.decrypt(&ciph);
            // `prop_assert_eq!` reports the failure to proptest, which enables
            // shrinking, instead of panicking.
            prop_assert_eq!(dec, Some(msg));
        }
    }

    // Exercises encrypt/decrypt in a default `cargo test` run, without the
    // optional `proptest` feature.
    #[test]
    fn encrypt_decrypt_roundtrip() {
        let mut rng = StdRng::from_seed([0x17; 32]);
        let key = SecretKey::new(&mut rng);
        let msg = [0xab_u8; 32];
        let ciph = key.public_key().encrypt(&mut rng, &msg);

        let dec: Option<[u8; 32]> = key.decrypt(&ciph);
        assert_eq!(dec, Some(msg));

        // A different key derives different pads, so the leading zero-marker
        // check fails and decryption is rejected.
        let other = SecretKey::new(&mut rng);
        let wrong: Option<[u8; 32]> = other.decrypt(&ciph);
        assert_eq!(wrong, None);
    }

    #[test]
    fn secret_key_repr_roundtrip() {
        for seed in [[0u8; 64], [1; 64], [255; 64]] {
            let key = SecretKey::from_uniform_bytes(&seed);
            let repr = key.repr();
            let decoded = SecretKey::from_repr(&repr).unwrap();
            assert_eq!(decoded.repr(), repr);
            // The decoded scalar must act as the same key, not merely re-encode
            // to the same bytes.
            assert_eq!(decoded.public_key(), key.public_key());
        }
    }
}