use crate::algorithms::{
asymmetric::kem::{KemAlgorithm, KyberSecurityLevel, RsaBits},
hash::HashAlgorithm,
};
use crate::define_wrapper;
use crate::error::{Error, FormatError, Result};
use crate::keys::asymmetric::kem::{EncapsulatedKey, SharedSecret};
use crate::keys::asymmetric::kem::{TypedKemKeyPair, TypedKemPrivateKey, TypedKemPublicKey};
use crate::keys::asymmetric::{TypedAsymmetricPrivateKeyTrait, TypedAsymmetricPublicKeyTrait};
use crate::traits::KemAlgorithmTrait;
use seal_crypto::prelude::{AsymmetricKeySet, Kem, Key};
use seal_crypto::schemes::asymmetric::post_quantum::kyber::{Kyber512, Kyber768, Kyber1024};
use seal_crypto::schemes::asymmetric::traditional::rsa::{Rsa2048, Rsa4096};
use seal_crypto::schemes::hash::{Sha256, Sha384, Sha512};
use std::ops::Deref;
/// Generates a zero-sized wrapper type implementing [`KemAlgorithmTrait`] for a single
/// `seal_crypto` KEM scheme.
///
/// * `$wrapper` – name of the unit struct to define (via `define_wrapper!(@unit_struct, ..)`).
/// * `$algo` – the `seal_crypto` scheme type backing the wrapper, e.g. `Kyber768` or
///   `Rsa2048<Sha256>`.
/// * `$algo_enum` – the [`KemAlgorithm`] value the wrapper reports; public/private keys and
///   encapsulated keys tagged with any other algorithm are rejected with
///   `FormatError::InvalidKeyType`.
macro_rules! impl_kem_algorithm {
    ($wrapper:ident, $algo:ty, $algo_enum:expr) => {
        define_wrapper!(@unit_struct, $wrapper, KemAlgorithmTrait, {
            // The fixed algorithm identifier this wrapper stands for.
            fn algorithm(&self) -> KemAlgorithm {
                $algo_enum
            }

            // Encapsulates a fresh shared secret to `public_key`, returning the secret together
            // with the ciphertext tagged with this wrapper's algorithm.
            fn encapsulate_key(
                &self,
                public_key: &TypedKemPublicKey,
            ) -> Result<(SharedSecret, EncapsulatedKey)> {
                // Refuse keys generated for a different algorithm before parsing their bytes.
                if public_key.algorithm != $algo_enum {
                    return Err(Error::FormatError(FormatError::InvalidKeyType));
                }
                // A `ty` fragment cannot be used directly as a path prefix, so name it locally.
                type Scheme = $algo;
                let native_pk =
                    <Scheme as AsymmetricKeySet>::PublicKey::from_bytes(&public_key.to_bytes())?;
                let (secret, ciphertext) = Scheme::encapsulate(&native_pk)?;
                let encapsulated = EncapsulatedKey {
                    key: ciphertext,
                    algorithm: $algo_enum,
                };
                Ok((SharedSecret(secret), encapsulated))
            }

            // Recovers the shared secret from `encapsulated_key` with `private_key`; both inputs
            // must be tagged with this wrapper's algorithm.
            fn decapsulate_key(
                &self,
                private_key: &TypedKemPrivateKey,
                encapsulated_key: &EncapsulatedKey,
            ) -> Result<SharedSecret> {
                if private_key.algorithm != $algo_enum || encapsulated_key.algorithm != $algo_enum {
                    return Err(Error::FormatError(FormatError::InvalidKeyType));
                }
                type Scheme = $algo;
                let native_sk =
                    <Scheme as AsymmetricKeySet>::PrivateKey::from_bytes(&private_key.to_bytes())?;
                let secret = Scheme::decapsulate(&native_sk, &encapsulated_key.key)?;
                Ok(SharedSecret(secret))
            }

            // Delegates key generation to `TypedKemKeyPair::generate` for this algorithm.
            fn generate_keypair(&self) -> Result<TypedKemKeyPair> {
                TypedKemKeyPair::generate($algo_enum)
            }

            // Boxes a copy of this (zero-sized) wrapper as a trait object.
            fn clone_box(&self) -> Box<dyn KemAlgorithmTrait> {
                Box::new(self.clone())
            }

            // Boxes this wrapper itself as a trait object.
            fn into_boxed(self) -> Box<dyn KemAlgorithmTrait> {
                Box::new(self)
            }
        });
    };
}
/// Runtime-selectable KEM algorithm: owns a boxed [`KemAlgorithmTrait`] implementation and
/// forwards every trait operation to it.
///
/// NOTE(review): the derived `Debug`/`Clone` require `Box<dyn KemAlgorithmTrait>` to implement
/// those traits somewhere outside this file — presumably `Clone` is built on `clone_box`; confirm.
#[derive(Debug, Clone)]
pub struct KemAlgorithmWrapper {
/// The concrete KEM implementation that every call is delegated to.
pub(crate) algorithm: Box<dyn KemAlgorithmTrait>,
}
impl Deref for KemAlgorithmWrapper {
type Target = Box<dyn KemAlgorithmTrait>;
fn deref(&self) -> &Self::Target {
&self.algorithm
}
}
/// Unwraps a [`KemAlgorithmWrapper`] into the boxed trait object it holds.
///
/// Implemented as `From` rather than a hand-written `Into` (Clippy `from_over_into`). The
/// standard library's blanket `impl<T, U: From<T>> Into<U> for T` still provides
/// `Into<Box<dyn KemAlgorithmTrait>>` for the wrapper, so existing `.into()` calls keep working.
/// `Box::<dyn KemAlgorithmTrait>::from(wrapper)` now works as well.
impl From<KemAlgorithmWrapper> for Box<dyn KemAlgorithmTrait> {
    fn from(wrapper: KemAlgorithmWrapper) -> Self {
        wrapper.algorithm
    }
}
impl KemAlgorithmWrapper {
pub fn new(algorithm: Box<dyn KemAlgorithmTrait>) -> Self {
Self { algorithm }
}
pub fn from_enum(algorithm: KemAlgorithm) -> Self {
let algorithm: Box<dyn KemAlgorithmTrait> = match algorithm {
KemAlgorithm::Rsa(RsaBits::B2048, HashAlgorithm::Sha256) => {
Box::new(Rsa2048Sha256Wrapper::new())
}
KemAlgorithm::Rsa(RsaBits::B2048, HashAlgorithm::Sha384) => {
Box::new(Rsa2048Sha384Wrapper::new())
}
KemAlgorithm::Rsa(RsaBits::B2048, HashAlgorithm::Sha512) => {
Box::new(Rsa2048Sha512Wrapper::new())
}
KemAlgorithm::Rsa(RsaBits::B4096, HashAlgorithm::Sha256) => {
Box::new(Rsa4096Sha256Wrapper::new())
}
KemAlgorithm::Rsa(RsaBits::B4096, HashAlgorithm::Sha384) => {
Box::new(Rsa4096Sha384Wrapper::new())
}
KemAlgorithm::Rsa(RsaBits::B4096, HashAlgorithm::Sha512) => {
Box::new(Rsa4096Sha512Wrapper::new())
}
KemAlgorithm::Kyber(KyberSecurityLevel::L512) => Box::new(Kyber512Wrapper::new()),
KemAlgorithm::Kyber(KyberSecurityLevel::L768) => Box::new(Kyber768Wrapper::new()),
KemAlgorithm::Kyber(KyberSecurityLevel::L1024) => Box::new(Kyber1024Wrapper::new()),
};
Self::new(algorithm)
}
pub fn generate_keypair(&self) -> Result<TypedKemKeyPair> {
self.algorithm.generate_keypair()
}
}
impl KemAlgorithmTrait for KemAlgorithmWrapper {
fn algorithm(&self) -> KemAlgorithm {
self.algorithm.algorithm()
}
fn encapsulate_key(
&self,
public_key: &TypedKemPublicKey,
) -> Result<(SharedSecret, EncapsulatedKey)> {
self.algorithm.encapsulate_key(public_key)
}
fn decapsulate_key(
&self,
private_key: &TypedKemPrivateKey,
encapsulated_key: &EncapsulatedKey,
) -> Result<SharedSecret> {
self.algorithm
.decapsulate_key(private_key, encapsulated_key)
}
fn generate_keypair(&self) -> Result<TypedKemKeyPair> {
self.algorithm.generate_keypair()
}
fn clone_box(&self) -> Box<dyn KemAlgorithmTrait> {
self.algorithm.clone_box()
}
fn into_boxed(self) -> Box<dyn KemAlgorithmTrait> {
self.algorithm
}
}
impl From<KemAlgorithm> for KemAlgorithmWrapper {
fn from(algorithm: KemAlgorithm) -> Self {
Self::from_enum(algorithm)
}
}
impl From<Box<dyn KemAlgorithmTrait>> for KemAlgorithmWrapper {
fn from(algorithm: Box<dyn KemAlgorithmTrait>) -> Self {
Self::new(algorithm)
}
}
// RSA-based KEM wrappers: one per (modulus size, hash) combination that
// `KemAlgorithmWrapper::from_enum` dispatches to.
impl_kem_algorithm!(
Rsa2048Sha256Wrapper,
Rsa2048<Sha256>,
KemAlgorithm::Rsa(RsaBits::B2048, HashAlgorithm::Sha256)
);
impl_kem_algorithm!(
Rsa2048Sha384Wrapper,
Rsa2048<Sha384>,
KemAlgorithm::Rsa(RsaBits::B2048, HashAlgorithm::Sha384)
);
impl_kem_algorithm!(
Rsa2048Sha512Wrapper,
Rsa2048<Sha512>,
KemAlgorithm::Rsa(RsaBits::B2048, HashAlgorithm::Sha512)
);
impl_kem_algorithm!(
Rsa4096Sha256Wrapper,
Rsa4096<Sha256>,
KemAlgorithm::Rsa(RsaBits::B4096, HashAlgorithm::Sha256)
);
impl_kem_algorithm!(
Rsa4096Sha384Wrapper,
Rsa4096<Sha384>,
KemAlgorithm::Rsa(RsaBits::B4096, HashAlgorithm::Sha384)
);
impl_kem_algorithm!(
Rsa4096Sha512Wrapper,
Rsa4096<Sha512>,
KemAlgorithm::Rsa(RsaBits::B4096, HashAlgorithm::Sha512)
);
// Post-quantum Kyber KEM wrappers: one per security level that
// `KemAlgorithmWrapper::from_enum` dispatches to.
impl_kem_algorithm!(
Kyber512Wrapper,
Kyber512,
KemAlgorithm::Kyber(KyberSecurityLevel::L512)
);
impl_kem_algorithm!(
Kyber768Wrapper,
Kyber768,
KemAlgorithm::Kyber(KyberSecurityLevel::L768)
);
impl_kem_algorithm!(
Kyber1024Wrapper,
Kyber1024,
KemAlgorithm::Kyber(KyberSecurityLevel::L1024)
);