use std::io;
use std::path::Path;
use rand::RngCore;
use thiserror::Error;
use ::rsa::traits::PublicKeyParts;
use ::rsa::RsaPrivateKey;
pub mod aes;
pub mod base64;
pub mod pem;
pub mod rsa;
pub use self::aes::{AES_BLOCK_SIZE, AES_KEYLEN};
pub use self::base64::{base64_decode, base64_encode};
#[derive(Debug, Error)]
pub enum CryptoError {
#[error("invalid key material")]
InvalidKey,
#[error("invalid PEM input: {0}")]
InvalidPem(String),
#[error("encryption failed")]
EncryptionFailed,
#[error("decryption failed")]
DecryptionFailed,
#[error("bad PKCS#7 padding")]
BadPadding,
#[error("base64 decode failed: {0}")]
Base64(String),
#[error(transparent)]
Io(#[from] io::Error),
}
/// Key material bundle: an RSA private key plus a symmetric AES key.
///
/// Fields are private; read them via the accessor methods.
pub struct Crypto {
    /// RSA private key used by `rsa_encrypt` / `rsa_decrypt`.
    rsa: RsaPrivateKey,
    /// Raw AES key bytes (`AES_KEYLEN` long).
    aes_key: [u8; AES_KEYLEN],
}
impl Crypto {
    /// Loads an RSA private key from the PEM file at `path` and pairs it with a
    /// freshly generated random AES key.
    ///
    /// # Errors
    ///
    /// Propagates any error from `pem::load_rsa_private_key`, or from
    /// [`Crypto::generate_aes_key`] if the OS random source fails.
    pub fn from_pem<P: AsRef<Path>>(path: P) -> Result<Self, CryptoError> {
        let rsa = pem::load_rsa_private_key(path.as_ref())?;
        let aes_key = Self::generate_aes_key()?;
        Ok(Self { aes_key, rsa })
    }

    /// Builds a `Crypto` from an already-loaded RSA key and AES key.
    pub fn from_parts(rsa: RsaPrivateKey, aes_key: [u8; AES_KEYLEN]) -> Self {
        Self { aes_key, rsa }
    }

    /// Generates a new random AES key using the operating system's random source
    /// (`rand::rngs::OsRng`).
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::Io`] if the OS random number generator cannot
    /// supply bytes.
    pub fn generate_aes_key() -> Result<[u8; AES_KEYLEN], CryptoError> {
        let mut key = [0u8; AES_KEYLEN];
        // `fill_bytes` panics on RNG failure; `try_fill_bytes` reports it, so the
        // `Result` in this signature actually carries the failure to the caller.
        rand::rngs::OsRng
            .try_fill_bytes(&mut key)
            .map_err(|e| CryptoError::Io(io::Error::new(io::ErrorKind::Other, e)))?;
        Ok(key)
    }

    /// Returns the raw AES key bytes held by this instance.
    ///
    /// This exposes secret key material; avoid logging or persisting it.
    pub fn aes_key(&self) -> &[u8; AES_KEYLEN] {
        &self.aes_key
    }

    /// Returns the RSA private key held by this instance.
    pub fn rsa_private_key(&self) -> &RsaPrivateKey {
        &self.rsa
    }

    /// Returns the RSA modulus size in bytes (`PublicKeyParts::size`).
    pub fn rsa_size(&self) -> usize {
        self.rsa.size()
    }

    /// Encrypts `msg` with `aes_key` into a new `Vec`, delegating to
    /// `aes::encrypt_to_vec`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `aes::encrypt_to_vec`.
    pub fn aes_encrypt(msg: &[u8], aes_key: &[u8; AES_KEYLEN]) -> Result<Vec<u8>, CryptoError> {
        aes::encrypt_to_vec(msg, aes_key)
    }

    /// Decrypts `enc` with `aes_key` into a new `Vec`, delegating to
    /// `aes::decrypt_to_vec`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `aes::decrypt_to_vec`.
    pub fn aes_decrypt(enc: &[u8], aes_key: &[u8; AES_KEYLEN]) -> Result<Vec<u8>, CryptoError> {
        aes::decrypt_to_vec(enc, aes_key)
    }

    /// Encrypts `msg` with `aes_key` into a buffer chain allocated from `pool`,
    /// delegating to `aes::encrypt_to_chain`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `aes::encrypt_to_chain`.
    pub fn dyn_aes_encrypt(
        msg: &[u8],
        aes_key: &[u8; AES_KEYLEN],
        pool: &crate::io::mbuf::MbufPool,
    ) -> Result<crate::io::mbuf::MbufQueue, CryptoError> {
        aes::encrypt_to_chain(msg, aes_key, pool)
    }

    /// Decrypts the buffer chain `enc` with `aes_key` into a new chain allocated
    /// from `pool`, delegating to `aes::decrypt_chain_to_chain`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `aes::decrypt_chain_to_chain`.
    pub fn dyn_aes_decrypt(
        enc: &mut crate::io::mbuf::MbufQueue,
        aes_key: &[u8; AES_KEYLEN],
        pool: &crate::io::mbuf::MbufPool,
    ) -> Result<crate::io::mbuf::MbufQueue, CryptoError> {
        aes::decrypt_chain_to_chain(enc, aes_key, pool)
    }

    /// Drains every buffer from `enc` into one contiguous `Vec`, then decrypts it
    /// with [`Crypto::aes_decrypt`].
    ///
    /// `enc` is left empty whether or not decryption succeeds.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Crypto::aes_decrypt`].
    pub fn dyn_aes_decrypt_to_vec(
        enc: &mut crate::io::mbuf::MbufQueue,
        aes_key: &[u8; AES_KEYLEN],
    ) -> Result<Vec<u8>, CryptoError> {
        // Reserve the full ciphertext length up front to avoid regrowth.
        let mut bytes = Vec::with_capacity(enc.total_len());
        while let Some(buf) = enc.pop_front() {
            bytes.extend_from_slice(buf.readable());
        }
        Self::aes_decrypt(&bytes, aes_key)
    }

    /// Encrypts the readable bytes of `msg` into a buffer chain allocated from
    /// `pool`.
    ///
    /// Returns the chain together with its total length in bytes.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `aes::encrypt_to_chain`.
    pub fn dyn_aes_encrypt_msg(
        msg: &crate::io::mbuf::Mbuf,
        aes_key: &[u8; AES_KEYLEN],
        pool: &crate::io::mbuf::MbufPool,
    ) -> Result<(crate::io::mbuf::MbufQueue, usize), CryptoError> {
        let chain = aes::encrypt_to_chain(msg.readable(), aes_key, pool)?;
        let n = chain.total_len();
        Ok((chain, n))
    }

    /// Encrypts `msg` with this instance's RSA key via `rsa::encrypt`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `rsa::encrypt`.
    pub fn rsa_encrypt(&self, msg: &[u8]) -> Result<Vec<u8>, CryptoError> {
        rsa::encrypt(&self.rsa, msg)
    }

    /// Decrypts `enc` with this instance's RSA key via `rsa::decrypt`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `rsa::decrypt`.
    pub fn rsa_decrypt(&self, enc: &[u8]) -> Result<Vec<u8>, CryptoError> {
        rsa::decrypt(&self.rsa, enc)
    }
}
impl std::fmt::Debug for Crypto {
    /// Renders only non-secret metadata (AES key length and RSA modulus size);
    /// the key bytes themselves are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Crypto");
        builder.field("aes_key_len", &self.aes_key.len());
        builder.field("rsa_size", &self.rsa_size());
        // `..` marks that secret fields were intentionally omitted.
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two independently generated keys have the configured length and differ.
    #[test]
    fn generate_aes_key_returns_distinct_keys() {
        let first = Crypto::generate_aes_key().unwrap();
        let second = Crypto::generate_aes_key().unwrap();
        assert_eq!(first.len(), AES_KEYLEN);
        assert_ne!(first, second);
    }

    /// Short plaintexts (including empty) survive an encrypt/decrypt round trip,
    /// and ciphertext is a non-empty whole number of AES blocks.
    #[test]
    fn aes_round_trip_short() {
        let key = Crypto::generate_aes_key().unwrap();
        let samples: [&[u8]; 4] = [b"", b"a", b"abcdefghij", b"this is a test"];
        for &plain in samples.iter() {
            let sealed = Crypto::aes_encrypt(plain, &key).unwrap();
            assert!(sealed.len() >= AES_BLOCK_SIZE);
            assert_eq!(sealed.len() % AES_BLOCK_SIZE, 0);
            let opened = Crypto::aes_decrypt(&sealed, &key).unwrap();
            assert_eq!(opened.as_slice(), plain);
        }
    }

    /// The `Debug` output names the type but does not dump the AES key bytes.
    #[test]
    fn debug_does_not_leak_key() {
        let rsa = RsaPrivateKey::new(&mut rand::rngs::OsRng, 2048).unwrap();
        let crypto = Crypto::from_parts(rsa, [0u8; AES_KEYLEN]);
        let rendered = format!("{crypto:?}");
        assert!(rendered.contains("Crypto"));
        assert!(!rendered.contains("0, 0, 0, 0"));
    }
}