use crate::types::ExportKeyFormat;
use rsa::{
    pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey},
    Oaep, Pss, RsaPrivateKey, RsaPublicKey,
};
use sha2::{Digest, Sha256, Sha384, Sha512};
/// Modulus size in bits. Used by `RSA::generate_key_pair` when no `modulus_length` is
/// given, and for the private key `RSA::import_key` generates on SPKI input.
const DEFAULT_RSA_KEY_SIZE: usize = 2_048;
/// PSS salt length in bytes. Used by `RSA::sign` / `RSA::verify` when no `salt_length`
/// is given.
const DEFAULT_SALT_LENGTH: usize = 32;
/// An RSA key pair plus the hash algorithm used for OAEP padding and as the default
/// hash for PSS signing/verification.
///
/// NOTE: when the pair comes from `RSA::import_key` with SPKI data, `private_key` is a
/// freshly generated random key that is unrelated to `public_key`.
#[derive(Clone)]
pub struct RsaKeyPair {
    /// Private half; used by `RSA::decrypt`, `RSA::sign` and the PKCS#8 export.
    pub private_key: RsaPrivateKey,
    /// Public half; used by `RSA::encrypt`, `RSA::verify` and the SPKI export.
    pub public_key: RsaPublicKey,
    /// Hash for OAEP padding, and the fallback hash when `sign`/`verify` get none.
    pub hash_algorithm: HashAlgorithm,
}
/// Hash function used for RSA-OAEP padding and for RSA-PSS message digests.
///
/// Deriving `Debug`, `PartialEq`, `Eq` and `Hash` lets callers log the selected
/// algorithm, compare it (e.g. `alg == HashAlgorithm::SHA256`) and use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HashAlgorithm {
    /// SHA-256 (32-byte digest).
    SHA256,
    /// SHA-384 (48-byte digest).
    SHA384,
    /// SHA-512 (64-byte digest).
    SHA512,
}
impl HashAlgorithm {
    /// Hashes `data` with the selected algorithm and returns the raw digest bytes.
    fn digest(&self, data: &[u8]) -> Vec<u8> {
        // One-shot `Digest::digest` is equivalent to new() + update() + finalize().
        match *self {
            Self::SHA256 => Sha256::digest(data).to_vec(),
            Self::SHA384 => Sha384::digest(data).to_vec(),
            Self::SHA512 => Sha512::digest(data).to_vec(),
        }
    }

    /// Builds an RSA-OAEP padding scheme parameterised by the selected hash.
    fn oaep_padding(&self) -> Oaep {
        match *self {
            Self::SHA256 => Oaep::new::<Sha256>(),
            Self::SHA384 => Oaep::new::<Sha384>(),
            Self::SHA512 => Oaep::new::<Sha512>(),
        }
    }
}
pub struct RSA;
impl RSA {
pub async fn generate_key_pair(
modulus_length: Option<usize>,
hash: Option<HashAlgorithm>,
) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
let mut rng = rand::thread_rng();
let bits = modulus_length.unwrap_or(DEFAULT_RSA_KEY_SIZE);
let private_key = RsaPrivateKey::new(&mut rng, bits)?;
let public_key = RsaPublicKey::from(&private_key);
let hash_algorithm = hash.unwrap_or(HashAlgorithm::SHA256);
Ok(RsaKeyPair {
private_key,
public_key,
hash_algorithm,
})
}
pub async fn export_key_public(
key: &RsaKeyPair,
format: ExportKeyFormat,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
match format {
ExportKeyFormat::SPKI => Ok(key.public_key.to_public_key_der()?.as_bytes().to_vec()),
ExportKeyFormat::JWK => {
Err("JWK format not yet implemented".into())
}
_ => Err("Unsupported export format for public key".into()),
}
}
pub async fn export_key_private(
key: &RsaKeyPair,
format: ExportKeyFormat,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
match format {
ExportKeyFormat::PKCS8 => Ok(key.private_key.to_pkcs8_der()?.as_bytes().to_vec()),
ExportKeyFormat::JWK => {
Err("JWK format not yet implemented".into())
}
_ => Err("Unsupported export format for private key".into()),
}
}
pub async fn import_key(
key_data: &[u8],
format: ExportKeyFormat,
_for_encryption: bool,
hash: Option<HashAlgorithm>,
) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
let hash_algorithm = hash.unwrap_or(HashAlgorithm::SHA256);
match format {
ExportKeyFormat::SPKI => {
let public_key = RsaPublicKey::from_public_key_der(key_data)?;
let private_key =
RsaPrivateKey::new(&mut rand::thread_rng(), DEFAULT_RSA_KEY_SIZE)?;
Ok(RsaKeyPair {
private_key,
public_key,
hash_algorithm,
})
}
ExportKeyFormat::JWK => {
Err("JWK format not yet implemented".into())
}
_ => Err("Unsupported import format".into()),
}
}
pub async fn encrypt(
key: &RsaKeyPair,
data: impl AsRef<[u8]>,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let mut rng = rand::thread_rng();
let padding = key.hash_algorithm.oaep_padding();
Ok(key.public_key.encrypt(&mut rng, padding, data.as_ref())?)
}
pub async fn decrypt(
key: &RsaKeyPair,
encrypted_data: impl AsRef<[u8]>,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let padding = key.hash_algorithm.oaep_padding();
Ok(key.private_key.decrypt(padding, encrypted_data.as_ref())?)
}
pub async fn sign(
key: &RsaKeyPair,
data: impl AsRef<[u8]>,
salt_length: Option<usize>,
hash: Option<HashAlgorithm>,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let hash_alg = hash.unwrap_or(key.hash_algorithm);
let hashed = hash_alg.digest(data.as_ref());
let mut rng = rand::thread_rng();
let salt_len = salt_length.unwrap_or(DEFAULT_SALT_LENGTH);
let padding = match hash_alg {
HashAlgorithm::SHA256 => Pss::new_with_salt::<Sha256>(salt_len),
HashAlgorithm::SHA384 => Pss::new_with_salt::<Sha384>(salt_len),
HashAlgorithm::SHA512 => Pss::new_with_salt::<Sha512>(salt_len),
};
Ok(key.private_key.sign_with_rng(&mut rng, padding, &hashed)?)
}
pub async fn verify(
key: &RsaKeyPair,
signature: impl AsRef<[u8]>,
data: impl AsRef<[u8]>,
salt_length: Option<usize>,
hash: Option<HashAlgorithm>,
) -> Result<bool, Box<dyn std::error::Error>> {
let hash_alg = hash.unwrap_or(key.hash_algorithm);
let hashed = hash_alg.digest(data.as_ref());
let salt_len = salt_length.unwrap_or(DEFAULT_SALT_LENGTH);
let padding = match hash_alg {
HashAlgorithm::SHA256 => Pss::new_with_salt::<Sha256>(salt_len),
HashAlgorithm::SHA384 => Pss::new_with_salt::<Sha384>(salt_len),
HashAlgorithm::SHA512 => Pss::new_with_salt::<Sha512>(salt_len),
};
Ok(key
.public_key
.verify(padding, &hashed, signature.as_ref())
.is_ok())
}
}