use std::ops::DerefMut;
use blake2b_simd::Params;
use openssl::rand::rand_priv_bytes;
use serde::Serialize;
use serde::{Deserialize, de::DeserializeOwned};
use thiserror::Error;
use zeroize::{Zeroize, Zeroizing};
use crate::modules::crypto::{AEADCipher, AES, AESKeySize, BlockCipher, CipherMode};
#[derive(Debug, Error)]
pub enum CryptoError {
#[error("Crypto error: {0}")]
Custom(String),
#[error("Some serde_json error happened, {:?}", .source)]
SerdeJson {
#[from]
source: serde_json::Error,
},
#[error("Some openssl error happened, {:?}", .source)]
OpenSSL {
#[from]
source: openssl::error::ErrorStack,
},
#[error("Some libvault error happened, {:?}", .source)]
RvError {
#[from]
source: crate::errors::RvError,
},
}
type Result<T, E = CryptoError> = std::result::Result<T, E>;
/// Symmetric key material: an AES key plus the additional authenticated data
/// (AAD) bound into every GCM operation performed with it.
///
/// Both buffers are wiped when the value is dropped (`#[zeroize(drop)]`).
/// `Default` yields empty buffers, which is not a usable key. Because the type
/// derives `Serialize`/`Deserialize`, any serialized form contains the raw key
/// and must itself be treated as secret.
#[derive(Default, Serialize, Deserialize, Zeroize)]
#[zeroize(drop)]
pub struct CryptoKey {
    /// Raw AES key bytes.
    key: Vec<u8>,
    /// Additional authenticated data passed to the cipher alongside the key.
    aad: Vec<u8>,
}
/// Holds a value of type `T` only in encrypted form, together with the
/// per-box key needed to recover it.
///
/// Both the ciphertext and the key are wiped when the box is dropped.
#[derive(Zeroize)]
#[zeroize(drop)]
pub struct EncryptedBox<T> {
    /// Encrypted envelope produced by `CryptoKey::encrypt`.
    ciphertext: Vec<u8>,
    /// Key that the envelope was sealed with.
    key: CryptoKey,
    /// Binds the box to `T` without storing one; holds no data, so nothing
    /// to wipe.
    #[zeroize(skip)]
    _marker: std::marker::PhantomData<T>,
}
impl CryptoKey {
pub fn new() -> Self {
let mut key = Zeroizing::new(vec![0u8; 32]);
let _ = rand_priv_bytes(key.deref_mut().as_mut_slice());
let mut aad = Zeroizing::new(vec![0u8; 16]);
let _ = rand_priv_bytes(aad.deref_mut().as_mut_slice());
Self {
key: key.to_vec(),
aad: aad.to_vec(),
}
}
pub fn encrypt<T: Serialize + DeserializeOwned>(&self, value: &T) -> Result<Vec<u8>> {
let plaintext = serde_json::to_vec(value)?;
let mut nonce = vec![0u8; 16];
let _ = rand_priv_bytes(&mut nonce);
let mut aes_encrypter = AES::new(
false,
Some(AESKeySize::AES256),
Some(CipherMode::GCM),
Some(self.key.clone()),
Some(nonce.clone()),
)?;
aes_encrypter.set_aad(self.aad.clone())?;
let ciphertext = aes_encrypter.encrypt(&plaintext)?;
let tag = aes_encrypter.get_tag()?;
let mut result = vec![];
result.extend_from_slice(&nonce);
result.extend_from_slice(&tag);
result.extend_from_slice(&ciphertext);
Ok(result)
}
pub fn decrypt<T: Serialize + DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
if value.len() < 32 {
return Err(CryptoError::Custom("Invalid ciphertext length".to_string()));
}
let nonce = value[0..16].to_vec();
let tag = value[16..32].to_vec();
let mut aes_decrypter = AES::new(
false,
Some(AESKeySize::AES256),
Some(CipherMode::GCM),
Some(self.key.clone()),
Some(nonce),
)?;
aes_decrypter.set_aad(self.aad.clone())?;
aes_decrypter.set_tag(tag)?;
let plaintext = aes_decrypter.decrypt(&value[32..].to_vec())?;
Ok(serde_json::from_slice(&plaintext)?)
}
}
impl<T: Serialize + DeserializeOwned> EncryptedBox<T> {
    /// Seals `value` under a freshly generated [`CryptoKey`].
    ///
    /// # Errors
    ///
    /// Propagates any error from serializing or encrypting `value`.
    pub fn new(value: &T) -> Result<Self> {
        let key = CryptoKey::new();
        key.encrypt(value).map(|ciphertext| Self {
            ciphertext,
            key,
            _marker: std::marker::PhantomData,
        })
    }

    /// Decrypts and deserializes a fresh copy of the boxed value.
    ///
    /// The returned `T` is an ordinary owned value; wiping it is the
    /// caller's responsibility.
    ///
    /// # Errors
    ///
    /// Propagates any error from decrypting or deserializing the envelope.
    pub fn get(&self) -> Result<T> {
        self.key.decrypt::<T>(&self.ciphertext)
    }
}
/// Returns the 32-byte (256-bit) BLAKE2b digest of the UTF-8 bytes of `key`.
///
/// Despite the parameter name, `key` is hashed as ordinary input data. No
/// BLAKE2 key is configured, so this is an unkeyed hash, not a MAC.
pub fn blake2b256_hash(key: &str) -> Vec<u8> {
    let mut params = Params::new();
    params.hash_length(32);

    let mut state = params.to_state();
    state.update(key.as_bytes());

    state.finalize().as_bytes().to_vec()
}