use crate::{
core::{
partitions::Partition, Encapsulation, KeyEncapsulation, MasterSecretKey, PublicKey,
UserSecretKey,
},
Error,
};
use abe_policy::EncryptionHint;
use cosmian_crypto_core::{
asymmetric_crypto::DhKeyPair,
kdf,
reexport::rand_core::{CryptoRng, CryptoRngCore, RngCore},
symmetric_crypto::SymKey,
KeyTrait,
};
use pqc_kyber::{
indcpa::{indcpa_dec, indcpa_enc, indcpa_keypair},
KYBER_INDCPA_BYTES, KYBER_INDCPA_PUBLICKEYBYTES, KYBER_INDCPA_SECRETKEYBYTES, KYBER_SYMBYTES,
};
use std::{
collections::{HashMap, HashSet},
hash::Hash,
ops::{Add, Div, Mul, Sub},
};
/// Constant info string passed to `eakem_hash!` together with the random
/// seed `K`, both in `encaps` and in `decaps`, to derive the encapsulation
/// tag and the final symmetric key.
pub(crate) const KEY_GEN_INFO: &[u8] = b"key generation info";
/// Returns the byte-wise XOR of two arrays of the same length `LENGTH`.
///
/// Used to mask/unmask the random seed with a KDF output.
#[inline]
fn xor<const LENGTH: usize>(a: &[u8; LENGTH], b: &[u8; LENGTH]) -> [u8; LENGTH] {
    // Start from a copy of `a` and fold `b` into it in place.
    let mut out = *a;
    for (lhs, rhs) in out.iter_mut().zip(b.iter()) {
        *lhs ^= *rhs;
    }
    out
}
/// Generates a fresh master secret key and master public key.
///
/// Three master scalars `u`, `v` and `s` are drawn from `rng`. The public key
/// carries `U` and `V`, derived from `u` and `v` through
/// `KeyPair::PublicKey::from`.
///
/// Every partition receives a subkey scalar `x_i`, and its public counterpart
/// is `H_i = S * x_i`, where `S = KeyPair::PublicKey::from(s)`. Partitions
/// tagged `EncryptionHint::Hybridized` also receive a Kyber IND-CPA key pair.
/// The secret half goes into the MSK and the public half into the MPK.
///
/// Randomness is consumed in this order: `u`, `v`, `s`, then, for each
/// partition in `HashMap` iteration order, `x_i` followed by the optional
/// Kyber key pair.
///
/// NOTE(review): the generic parameter `R` is not used by this function; it
/// is kept for signature compatibility.
pub fn setup<const PUBLIC_KEY_LENGTH: usize, const PRIVATE_KEY_LENGTH: usize, R, KeyPair>(
    rng: &mut impl CryptoRngCore,
    partitions: &HashMap<Partition, EncryptionHint>,
) -> (
    MasterSecretKey<PRIVATE_KEY_LENGTH, KeyPair::PrivateKey>,
    PublicKey<PUBLIC_KEY_LENGTH, KeyPair::PublicKey>,
)
where
    KeyPair: DhKeyPair<PUBLIC_KEY_LENGTH, PRIVATE_KEY_LENGTH>,
    KeyPair::PublicKey: From<KeyPair::PrivateKey>,
    for<'a, 'b> &'a KeyPair::PublicKey: Add<&'b KeyPair::PublicKey, Output = KeyPair::PublicKey>
        + Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PublicKey>,
    for<'a, 'b> &'a KeyPair::PrivateKey: Add<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Sub<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Div<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>,
{
    // Master scalars, drawn in this fixed order.
    let u = KeyPair::PrivateKey::new(rng);
    let v = KeyPair::PrivateKey::new(rng);
    let s = KeyPair::PrivateKey::new(rng);

    let public_u = KeyPair::PublicKey::from(u.clone());
    let public_v = KeyPair::PublicKey::from(v.clone());
    let public_s = KeyPair::PublicKey::from(s.clone());

    let mut secret_subkeys = HashMap::with_capacity(partitions.len());
    let mut public_subkeys = HashMap::with_capacity(partitions.len());

    for (partition, hint) in partitions {
        // The classic subkey is drawn before the optional Kyber key pair.
        let x_i = KeyPair::PrivateKey::new(rng);
        let h_i = &public_s * &x_i;

        let mut kyber_sk = None;
        let mut kyber_pk = None;
        if *hint == EncryptionHint::Hybridized {
            let mut sk = [0; KYBER_INDCPA_SECRETKEYBYTES];
            let mut pk = [0; KYBER_INDCPA_PUBLICKEYBYTES];
            indcpa_keypair(&mut pk, &mut sk, None, rng);
            kyber_sk = Some(sk);
            kyber_pk = Some(pk);
        }

        secret_subkeys.insert(partition.clone(), (kyber_sk, x_i));
        public_subkeys.insert(partition.clone(), (kyber_pk, h_i));
    }

    (
        MasterSecretKey {
            u,
            v,
            s,
            x: secret_subkeys,
        },
        PublicKey {
            U: public_u,
            V: public_v,
            H: public_subkeys,
        },
    )
}
/// Derives a user secret key granting access to the partitions of
/// `decryption_set`.
///
/// A fresh scalar `a` is drawn from `rng`, and `b` is solved from
/// `s = a * u + b * v`, that is `b = (s - a * u) / v`.
///
/// The user receives a copy of every master subkey whose partition is both in
/// `decryption_set` and in `msk`. Partitions unknown to `msk` are skipped.
pub fn keygen<const PUBLIC_KEY_LENGTH: usize, const PRIVATE_KEY_LENGTH: usize, R, KeyPair>(
    rng: &mut R,
    msk: &MasterSecretKey<PRIVATE_KEY_LENGTH, KeyPair::PrivateKey>,
    decryption_set: &HashSet<Partition>,
) -> UserSecretKey<PRIVATE_KEY_LENGTH, KeyPair::PrivateKey>
where
    R: CryptoRng + RngCore,
    KeyPair: DhKeyPair<PUBLIC_KEY_LENGTH, PRIVATE_KEY_LENGTH>,
    KeyPair::PublicKey: From<KeyPair::PrivateKey>,
    KeyPair::PrivateKey: Hash,
    for<'a, 'b> &'a KeyPair::PublicKey: Add<&'b KeyPair::PublicKey, Output = KeyPair::PublicKey>
        + Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PublicKey>,
    for<'a, 'b> &'a KeyPair::PrivateKey: Add<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Sub<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Div<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>,
{
    let a = KeyPair::PrivateKey::new(rng);

    // b = (s - a*u) / v, so that a*u + b*v = s.
    let a_u = &a * &msk.u;
    let s_minus_a_u = &msk.s - &a_u;
    let b = &s_minus_a_u / &msk.v;

    let x = decryption_set
        .iter()
        .filter_map(|partition| msk.x.get(partition).cloned())
        .collect();

    UserSecretKey { a, b, x }
}
/// Encapsulates a fresh symmetric key for the partitions of `encryption_set`.
///
/// A random seed `K` and a random scalar `r` are drawn from `rng`. The
/// encapsulation carries:
/// - `C = U * r` and `D = V * r`,
/// - one key encapsulation per partition of `encryption_set` that is known to
///   `mpk` (unknown partitions are skipped),
/// - a tag.
///
/// For each known partition the seed is masked as `E_i = K xor KDF(H_i * r)`.
/// If the partition also has a Kyber public key, `E_i` is additionally
/// encrypted with Kyber IND-CPA, using fresh random coins (hybrid mode).
///
/// Both the returned symmetric key and the tag are derived from `K` with
/// `eakem_hash!`.
///
/// NOTE(review): Kyber IND-CPA encrypts `KYBER_SYMBYTES`-byte messages, so the
/// hybrid path presumably requires `SYM_KEY_LENGTH == KYBER_SYMBYTES`; confirm.
pub fn encaps<
    const TAG_LENGTH: usize,
    const SYM_KEY_LENGTH: usize,
    const PUBLIC_KEY_LENGTH: usize,
    const PRIVATE_KEY_LENGTH: usize,
    SymmetricKey,
    KeyPair,
>(
    rng: &mut impl CryptoRngCore,
    mpk: &PublicKey<PUBLIC_KEY_LENGTH, KeyPair::PublicKey>,
    encryption_set: &HashSet<Partition>,
) -> (
    SymmetricKey,
    Encapsulation<TAG_LENGTH, SYM_KEY_LENGTH, PUBLIC_KEY_LENGTH, KeyPair::PublicKey>,
)
where
    SymmetricKey: SymKey<SYM_KEY_LENGTH>,
    KeyPair: DhKeyPair<PUBLIC_KEY_LENGTH, PRIVATE_KEY_LENGTH>,
    KeyPair::PublicKey: From<KeyPair::PrivateKey>,
    for<'a, 'b> &'a KeyPair::PublicKey: Add<&'b KeyPair::PublicKey, Output = KeyPair::PublicKey>
        + Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PublicKey>,
    for<'a, 'b> &'a KeyPair::PrivateKey: Add<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Sub<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Div<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>,
{
    let mut K = [0; SYM_KEY_LENGTH];
    rng.fill_bytes(&mut K);
    let r = KeyPair::PrivateKey::new(rng);
    let C = &mpk.U * &r;
    let D = &mpk.V * &r;
    let mut E = HashSet::with_capacity(encryption_set.len());
    for partition in encryption_set {
        if let Some((pk_i, H_i)) = mpk.H.get(partition) {
            let E_i = xor(&K, &kdf!(SYM_KEY_LENGTH, &(H_i * &r).to_bytes()));
            if let Some(pk_i) = pk_i {
                // Kyber IND-CPA encryption needs fresh uniformly random coins for
                // every ciphertext. With fixed, publicly known coins the noise
                // terms are known, so anyone holding the public key can strip
                // them off and recover `E_i`.
                let mut coins = [0; KYBER_SYMBYTES];
                rng.fill_bytes(&mut coins);
                let mut EPQ_i = [0; KYBER_INDCPA_BYTES];
                indcpa_enc(&mut EPQ_i, &E_i, pk_i, &coins);
                E.insert(KeyEncapsulation::HybridEncapsulation(Box::new(EPQ_i)));
            } else {
                E.insert(KeyEncapsulation::ClassicEncapsulation(Box::new(E_i)));
            }
        }
    }
    let (tag, K) = eakem_hash!(TAG_LENGTH, SYM_KEY_LENGTH, &K, KEY_GEN_INFO);
    (SymmetricKey::from_bytes(K), Encapsulation { C, D, tag, E })
}
/// Recovers the symmetric key encapsulated in `encapsulation` using `usk`.
///
/// `C * a + D * b` is computed once. Every pair (key encapsulation, user
/// subkey) is then tried in turn:
/// - a hybrid encapsulation is first Kyber-decrypted with the subkey's Kyber
///   secret; subkeys without one are skipped for it,
/// - the result is unmasked with `KDF((C * a + D * b) * x_j)`,
/// - the candidate is accepted as soon as its `eakem_hash!` tag matches
///   `encapsulation.tag`.
///
/// # Errors
///
/// Returns `Error::InsufficientAccessPolicy` when no pair produces a matching
/// tag.
pub fn decaps<
    const TAG_LENGTH: usize,
    const SYM_KEY_LENGTH: usize,
    const PUBLIC_KEY_LENGTH: usize,
    const PRIVATE_KEY_LENGTH: usize,
    SymmetricKey,
    KeyPair,
>(
    usk: &UserSecretKey<PRIVATE_KEY_LENGTH, KeyPair::PrivateKey>,
    encapsulation: &Encapsulation<
        TAG_LENGTH,
        SYM_KEY_LENGTH,
        PUBLIC_KEY_LENGTH,
        KeyPair::PublicKey,
    >,
) -> Result<SymmetricKey, Error>
where
    SymmetricKey: SymKey<SYM_KEY_LENGTH>,
    KeyPair: DhKeyPair<PUBLIC_KEY_LENGTH, PRIVATE_KEY_LENGTH>,
    KeyPair::PublicKey: From<KeyPair::PrivateKey>,
    KeyPair::PrivateKey: Hash,
    for<'a, 'b> &'a KeyPair::PublicKey: Add<&'b KeyPair::PublicKey, Output = KeyPair::PublicKey>
        + Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PublicKey>,
    for<'a, 'b> &'a KeyPair::PrivateKey: Add<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Sub<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
        + Div<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>,
{
    // Shared across all candidates: C*a + D*b.
    let shared = &(&encapsulation.C * &usk.a) + &(&encapsulation.D * &usk.b);

    for key_encapsulation in &encapsulation.E {
        for (kyber_sk, x_j) in &usk.x {
            let masked_seed = match (key_encapsulation, kyber_sk) {
                (KeyEncapsulation::ClassicEncapsulation(E_i), _) => **E_i,
                (KeyEncapsulation::HybridEncapsulation(EPQ_i), Some(sk)) => {
                    let mut plaintext = [0; SYM_KEY_LENGTH];
                    indcpa_dec(&mut plaintext, &**EPQ_i, sk);
                    plaintext
                }
                // A subkey without a Kyber secret cannot open a hybrid encapsulation.
                (KeyEncapsulation::HybridEncapsulation(_), None) => continue,
            };

            let mask = kdf!(SYM_KEY_LENGTH, &(&shared * x_j).to_bytes());
            let seed = xor(&masked_seed, &mask);
            let (tag, K) = eakem_hash!(TAG_LENGTH, SYM_KEY_LENGTH, &seed, KEY_GEN_INFO);
            if tag == encapsulation.tag {
                return Ok(SymmetricKey::from_bytes(K));
            }
        }
    }

    Err(Error::InsufficientAccessPolicy)
}
pub fn update<const PUBLIC_KEY_LENGTH: usize, const PRIVATE_KEY_LENGTH: usize, R, KeyPair>(
rng: &mut impl CryptoRngCore,
msk: &mut MasterSecretKey<PRIVATE_KEY_LENGTH, KeyPair::PrivateKey>,
mpk: &mut PublicKey<PUBLIC_KEY_LENGTH, KeyPair::PublicKey>,
partitions_set: &HashMap<Partition, EncryptionHint>,
) -> Result<(), Error>
where
KeyPair: DhKeyPair<PUBLIC_KEY_LENGTH, PRIVATE_KEY_LENGTH>,
KeyPair::PublicKey: From<KeyPair::PrivateKey>,
for<'a, 'b> &'a KeyPair::PublicKey: Add<&'b KeyPair::PublicKey, Output = KeyPair::PublicKey>
+ Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PublicKey>,
for<'a, 'b> &'a KeyPair::PrivateKey: Add<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
+ Sub<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
+ Mul<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>
+ Div<&'b KeyPair::PrivateKey, Output = KeyPair::PrivateKey>,
{
let S = KeyPair::PublicKey::from(msk.s.clone());
let mut new_x = HashMap::with_capacity(partitions_set.len());
let mut new_H = HashMap::with_capacity(partitions_set.len());
for (partition, &is_hybridized) in partitions_set {
if let Some((sk_i, x_i)) = msk.x.get(partition) {
let H_i = &S * x_i;
let (sk_i, pk_i) = if is_hybridized == EncryptionHint::Hybridized {
let (pk_i, _) = mpk.H.get(partition).ok_or_else(|| {
Error::KeyError(
"Kyber public key cannot be computed from the secret key.".to_string(),
)
})?;
if sk_i.is_some() {
if pk_i.is_some() {
(*sk_i, *pk_i)
} else {
return Err(Error::KeyError(
"Kyber public key cannot be computed from the secret key.".to_string(),
));
}
} else {
let (mut sk_i, mut pk_i) = (
[0; KYBER_INDCPA_SECRETKEYBYTES],
[0; KYBER_INDCPA_PUBLICKEYBYTES],
);
indcpa_keypair(&mut pk_i, &mut sk_i, None, rng);
(Some(sk_i), Some(pk_i))
}
} else {
(None, None)
};
new_x.insert(partition.clone(), (sk_i, x_i.clone()));
new_H.insert(partition.clone(), (pk_i, H_i));
} else {
let x_i = KeyPair::PrivateKey::new(rng);
let H_i = &S * &x_i;
let (sk_pq, pk_pq) = if is_hybridized == EncryptionHint::Hybridized {
let (mut sk_pq, mut pk_pq) = (
[0; KYBER_INDCPA_SECRETKEYBYTES],
[0; KYBER_INDCPA_PUBLICKEYBYTES],
);
indcpa_keypair(&mut pk_pq, &mut sk_pq, None, rng);
(Some(sk_pq), Some(pk_pq))
} else {
(None, None)
};
new_x.insert(partition.clone(), (sk_pq, x_i));
new_H.insert(partition.clone(), (pk_pq, H_i));
}
}
msk.x = new_x;
mpk.H = new_H;
Ok(())
}
/// Refreshes `usk` with the current master subkeys of `decryption_set`.
///
/// When `keep_old_rights` is `false`, the user's previous subkeys are
/// discarded first. The current `msk` subkey of every partition of
/// `decryption_set` is then added. Partitions unknown to `msk` are ignored.
pub fn refresh<const PRIVATE_KEY_LENGTH: usize, PrivateKey>(
    msk: &MasterSecretKey<PRIVATE_KEY_LENGTH, PrivateKey>,
    usk: &mut UserSecretKey<PRIVATE_KEY_LENGTH, PrivateKey>,
    decryption_set: &HashSet<Partition>,
    keep_old_rights: bool,
) where
    PrivateKey: KeyTrait<PRIVATE_KEY_LENGTH> + Hash,
{
    if !keep_old_rights {
        usk.x = Default::default();
    }
    usk.x.extend(
        decryption_set
            .iter()
            .filter_map(|partition| msk.x.get(partition))
            .cloned(),
    );
}
// Unit tests for the core scheme functions, called through the crate's
// `setup!`, `keygen!`, `encaps!`, `decaps!`, `update!` and `refresh!` macros.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{decaps, encaps, keygen, refresh, setup, update};
    use cosmian_crypto_core::{
        asymmetric_crypto::curve25519::X25519KeyPair, reexport::rand_core::SeedableRng,
        symmetric_crypto::aes_256_gcm_pure::Aes256GcmCrypto, CsRng,
    };

    // NOTE(review): these constants and aliases are not referenced directly by
    // the test bodies below; presumably the macros expand to code that uses
    // them (confirm against the macro definitions).
    const TAG_LENGTH: usize = 32;
    const SYM_KEY_LENGTH: usize = 32;
    type KeyPair = X25519KeyPair;
    #[allow(clippy::upper_case_acronyms)]
    type DEM = Aes256GcmCrypto;

    /// Sanity check: a full pqc_kyber KEM round trip recovers the shared secret.
    #[test]
    fn test_kyber() {
        let mut rng = CsRng::from_entropy();
        let keypair = pqc_kyber::keypair(&mut rng);
        let (ct, ss) = pqc_kyber::encapsulate(&keypair.public, &mut rng).unwrap();
        let res = pqc_kyber::decapsulate(&ct, &keypair.secret).unwrap();
        assert_eq!(ss, res, "Decapsulation failed!");
    }

    /// End-to-end scenario: setup with one hybridized and one classic
    /// partition, encapsulation/decapsulation, then an update that changes
    /// hybridization and adds a partition, followed by a user key refresh.
    #[test]
    fn test_cover_crypt() -> Result<(), Error> {
        let admin_partition = Partition(b"admin".to_vec());
        let dev_partition = Partition(b"dev".to_vec());
        let partitions_set = HashMap::from([
            (admin_partition.clone(), EncryptionHint::Hybridized),
            (dev_partition.clone(), EncryptionHint::Classic),
        ]);
        let users_set = vec![
            HashSet::from([dev_partition.clone()]),
            HashSet::from([admin_partition.clone(), dev_partition.clone()]),
        ];
        let admin_target_set = HashSet::from([admin_partition.clone()]);
        let mut rng = CsRng::from_entropy();
        let (mut msk, mut mpk) = setup!(&mut rng, &partitions_set);
        // Hybridized partitions carry a Kyber secret key, classic ones do not.
        let admin_secret_subkeys = msk.x.get(&admin_partition);
        assert!(admin_secret_subkeys.is_some());
        assert!(admin_secret_subkeys.unwrap().0.is_some());
        let dev_secret_subkeys = msk.x.get(&dev_partition);
        assert!(dev_secret_subkeys.is_some());
        assert!(dev_secret_subkeys.unwrap().0.is_none());
        let mut dev_usk = keygen!(&mut rng, &msk, &users_set[0]);
        let admin_usk = keygen!(&mut rng, &msk, &users_set[1]);
        // Encapsulating for the hybridized admin partition yields a single
        // hybrid encapsulation.
        let (sym_key, encapsulation) = encaps!(&mut rng, &mpk, &admin_target_set);
        assert_eq!(encapsulation.E.len(), 1);
        for key_encapsulation in &encapsulation.E {
            if let KeyEncapsulation::ClassicEncapsulation(_) = key_encapsulation {
                panic!("Wrong hybridization type");
            }
        }
        let res0 = decaps!(&dev_usk, &encapsulation);
        assert!(res0.is_err(), "User 0 shouldn't be able to decapsulate!");
        let res1 = decaps!(&admin_usk, &encapsulation)?;
        assert_eq!(sym_key, res1, "Wrong decapsulation for user 1!");
        // Update: admin is removed, dev becomes hybridized, client is added.
        let client_partition = Partition(b"client".to_vec());
        let new_partitions_set = HashMap::from([
            (dev_partition.clone(), EncryptionHint::Hybridized),
            (client_partition.clone(), EncryptionHint::Classic),
        ]);
        let client_target_set = HashSet::from([client_partition.clone()]);
        update!(&mut rng, &mut msk, &mut mpk, &new_partitions_set)?;
        refresh!(
            &msk,
            &mut dev_usk,
            &HashSet::from([dev_partition.clone()]),
            false
        );
        let dev_secret_subkeys = msk.x.get(&dev_partition);
        assert!(dev_secret_subkeys.is_some());
        assert!(dev_secret_subkeys.unwrap().0.is_some());
        let client_secret_subkeys = msk.x.get(&client_partition);
        assert!(client_secret_subkeys.is_some());
        assert!(client_secret_subkeys.unwrap().0.is_none());
        assert_eq!(dev_usk.x.len(), 1);
        // NOTE(review): this re-checks the *old* `encapsulation`, created
        // before the update, so it repeats the check above; confirm whether a
        // fresh encapsulation for the now-hybridized dev partition was meant.
        for key_encapsulation in &encapsulation.E {
            if let KeyEncapsulation::ClassicEncapsulation(_) = key_encapsulation {
                panic!("Wrong hybridization type");
            }
        }
        // Encapsulating for the classic client partition yields a single
        // classic encapsulation.
        let (sym_key, new_encapsulation) = encaps!(&mut rng, &mpk, &client_target_set);
        assert_eq!(new_encapsulation.E.len(), 1);
        for key_encapsulation in &new_encapsulation.E {
            if let KeyEncapsulation::HybridEncapsulation(_) = key_encapsulation {
                panic!("Wrong hybridization type");
            }
        }
        let res0 = decaps!(&dev_usk, &encapsulation);
        assert!(
            res0.is_err(),
            "User 0 should not be able to decapsulate the old encapsulation."
        );
        let res1 = decaps!(&admin_usk, &new_encapsulation);
        assert!(
            res1.is_err(),
            "User 1 should not be able to decapsulate the new encapsulation."
        );
        let client_usk = keygen!(&mut rng, &msk, &HashSet::from([client_partition]));
        let res0 = decaps!(&client_usk, &new_encapsulation);
        match res0 {
            Err(err) => panic!("Client should be able to decapsulate: {err:?}"),
            Ok(res) => assert_eq!(sym_key, res, "Wrong decapsulation."),
        }
        Ok(())
    }

    /// `update` drops partitions absent from the new set, keeps shared ones
    /// and adds new ones, in both the MSK and the MPK.
    #[test]
    fn test_master_keys_update() -> Result<(), Error> {
        let partition_1 = Partition(b"1".to_vec());
        let partition_2 = Partition(b"2".to_vec());
        let partitions_set = HashMap::from([
            (partition_1.clone(), EncryptionHint::Classic),
            (partition_2.clone(), EncryptionHint::Hybridized),
        ]);
        let mut rng = CsRng::from_entropy();
        let (mut msk, mut mpk) = setup!(&mut rng, &partitions_set);
        let partition_3 = Partition(b"3".to_vec());
        let new_partitions_set = HashMap::from([
            (partition_2.clone(), EncryptionHint::Hybridized),
            (partition_3.clone(), EncryptionHint::Classic),
        ]);
        update!(&mut rng, &mut msk, &mut mpk, &new_partitions_set)?;
        assert!(!msk.x.contains_key(&partition_1));
        assert!(msk.x.contains_key(&partition_2));
        assert!(msk.x.contains_key(&partition_3));
        assert!(!mpk.H.contains_key(&partition_1));
        assert!(mpk.H.contains_key(&partition_2));
        assert!(mpk.H.contains_key(&partition_3));
        Ok(())
    }

    /// `refresh` with `keep_old_rights == false` leaves the user with exactly
    /// the current subkeys of the requested partitions.
    #[test]
    fn test_user_key_refresh() -> Result<(), Error> {
        let partition_1 = Partition(b"1".to_vec());
        let partition_2 = Partition(b"2".to_vec());
        let partition_3 = Partition(b"3".to_vec());
        let partitions_set = HashMap::from([
            (partition_1.clone(), EncryptionHint::Hybridized),
            (partition_2.clone(), EncryptionHint::Hybridized),
            (partition_3.clone(), EncryptionHint::Hybridized),
        ]);
        let mut rng = CsRng::from_entropy();
        let (mut msk, mut mpk) = setup!(&mut rng, &partitions_set);
        let mut usk = keygen!(
            &mut rng,
            &msk,
            &HashSet::from([partition_1.clone(), partition_2.clone()])
        );
        let partition_4 = Partition(b"4".to_vec());
        let new_partition_set = HashMap::from([
            (partition_2.clone(), EncryptionHint::Hybridized),
            (partition_3.clone(), EncryptionHint::Classic),
            (partition_4.clone(), EncryptionHint::Classic),
        ]);
        // Keep the pre-update MSK to check that old subkeys are gone.
        let old_msk = msk.clone();
        update!(&mut rng, &mut msk, &mut mpk, &new_partition_set)?;
        refresh!(
            &msk,
            &mut usk,
            &HashSet::from([partition_2.clone(), partition_4.clone()]),
            false
        );
        assert!(!usk.x.contains(old_msk.x.get(&partition_1).unwrap()));
        assert!(usk.x.contains(msk.x.get(&partition_2).unwrap()));
        assert!(!usk.x.contains(old_msk.x.get(&partition_3).unwrap()));
        assert!(usk.x.contains(msk.x.get(&partition_4).unwrap()));
        Ok(())
    }
}