use aes_gcm::{
aead::{Aead, KeyInit, Nonce},
Aes256Gcm,
};
use chacha20poly1305::ChaCha20Poly1305;
use rand::RngCore;
use sha2::{Digest, Sha256};
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::error::{CrablockError, Result};
/// Length in bytes of an encryption key (256 bits).
pub const KEY_SIZE: usize = 256 / 8;
/// Generic nonce length in bytes (96 bits).
pub const NONCE_SIZE: usize = 96 / 8;
/// Length in bytes of an AEAD authentication tag (128 bits).
pub const TAG_SIZE: usize = 128 / 8;
/// Nonce length in bytes used for AES-256-GCM.
pub const AES_GCM_NONCE_SIZE: usize = NONCE_SIZE;
/// Nonce length in bytes used for ChaCha20-Poly1305.
pub const CHACHA_NONCE_SIZE: usize = NONCE_SIZE;
/// Authenticated-encryption algorithm used for a payload.
///
/// serde serializes variants with `rename_all = "snake_case"`, which yields
/// `"aes256_gcm"` and `"cha_cha20_poly1305"`. Those differ from the names
/// returned by [`EncryptionAlgorithm::as_str`], so each variant also carries a
/// serde `alias` for its `as_str` spelling. Both forms deserialize, and both are
/// accepted by `FromStr`. Serialized output is unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EncryptionAlgorithm {
    /// AES with a 256-bit key in Galois/Counter Mode.
    #[serde(alias = "aes_256_gcm")]
    Aes256Gcm,
    /// ChaCha20 stream cipher with the Poly1305 authenticator.
    #[serde(alias = "chacha20_poly1305")]
    ChaCha20Poly1305,
}

impl EncryptionAlgorithm {
    /// Nonce length in bytes that this algorithm requires.
    pub fn nonce_size(&self) -> usize {
        match self {
            EncryptionAlgorithm::Aes256Gcm => AES_GCM_NONCE_SIZE,
            EncryptionAlgorithm::ChaCha20Poly1305 => CHACHA_NONCE_SIZE,
        }
    }

    /// Canonical lowercase name of the algorithm; accepted by `FromStr`.
    pub fn as_str(&self) -> &'static str {
        match self {
            EncryptionAlgorithm::Aes256Gcm => "aes_256_gcm",
            EncryptionAlgorithm::ChaCha20Poly1305 => "chacha20_poly1305",
        }
    }
}

impl std::str::FromStr for EncryptionAlgorithm {
    type Err = CrablockError;

    /// Parses an algorithm name, ignoring case.
    ///
    /// Accepts the [`EncryptionAlgorithm::as_str`] names, hyphenated and compact
    /// spellings, and the serde-serialized names (`aes256_gcm`,
    /// `cha_cha20_poly1305`), so serialized values parse back.
    ///
    /// # Errors
    /// Returns [`CrablockError::UnsupportedAlgorithm`] for any other input.
    fn from_str(s: &str) -> Result<Self> {
        match s.to_lowercase().as_str() {
            "aes_256_gcm" | "aes-256-gcm" | "aes256gcm" | "aes256_gcm" => {
                Ok(EncryptionAlgorithm::Aes256Gcm)
            }
            "chacha20_poly1305" | "chacha20-poly1305" | "chacha20poly1305"
            | "cha_cha20_poly1305" => Ok(EncryptionAlgorithm::ChaCha20Poly1305),
            _ => Err(CrablockError::UnsupportedAlgorithm(format!(
                "Unknown algorithm: {s}"
            ))),
        }
    }
}
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct EncryptionKey {
pub key: [u8; KEY_SIZE],
}
impl EncryptionKey {
pub fn new(key: [u8; KEY_SIZE]) -> Self {
Self { key }
}
pub fn from_hex(hex_str: &str) -> Result<Self> {
let bytes = hex::decode(hex_str)
.map_err(|e| CrablockError::InvalidKey(format!("Invalid hex: {e}")))?;
if bytes.len() != KEY_SIZE {
return Err(CrablockError::InvalidKey(format!(
"Key must be {} bytes, got {}",
KEY_SIZE,
bytes.len()
)));
}
let mut key = [0u8; KEY_SIZE];
key.copy_from_slice(&bytes);
Ok(Self::new(key))
}
pub fn from_base64(b64_str: &str) -> Result<Self> {
use base64::Engine;
let bytes = base64::engine::general_purpose::STANDARD
.decode(b64_str)
.map_err(|e| CrablockError::InvalidKey(format!("Invalid base64: {e}")))?;
if bytes.len() != KEY_SIZE {
return Err(CrablockError::InvalidKey(format!(
"Key must be {} bytes, got {}",
KEY_SIZE,
bytes.len()
)));
}
let mut key = [0u8; KEY_SIZE];
key.copy_from_slice(&bytes);
Ok(Self::new(key))
}
pub fn generate_random() -> Self {
let mut key = [0u8; KEY_SIZE];
rand::thread_rng().fill_bytes(&mut key);
Self::new(key)
}
}
/// AEAD encryptor bound to one key, one algorithm and one nonce.
///
/// [`Encryptor::new`] draws a random nonce once, and every call to
/// [`Encryptor::encrypt`] reuses that same stored nonce. AES-GCM and
/// ChaCha20-Poly1305 lose their confidentiality and integrity guarantees when a
/// (key, nonce) pair encrypts more than one message. Create a fresh `Encryptor`,
/// or supply a fresh nonce via [`Encryptor::with_nonce`], for each message. Read
/// the nonce back with [`Encryptor::nonce`]; [`Decryptor`] needs it to decrypt.
pub struct Encryptor {
    algorithm: EncryptionAlgorithm,
    key: EncryptionKey,
    nonce: Vec<u8>,
}

impl Encryptor {
    /// Creates an encryptor with a random nonce of
    /// [`EncryptionAlgorithm::nonce_size`] bytes from the thread-local RNG.
    pub fn new(algorithm: EncryptionAlgorithm, key: EncryptionKey) -> Self {
        let mut nonce = vec![0u8; algorithm.nonce_size()];
        rand::thread_rng().fill_bytes(&mut nonce);
        Self {
            algorithm,
            key,
            nonce,
        }
    }

    /// Replaces the nonce.
    ///
    /// The length is not checked here; [`Encryptor::encrypt`] rejects a nonce
    /// whose length differs from [`EncryptionAlgorithm::nonce_size`].
    pub fn with_nonce(mut self, nonce: Vec<u8>) -> Self {
        self.nonce = nonce;
        self
    }

    /// Encrypts `plaintext` and returns the ciphertext, including the
    /// authentication tag produced by the AEAD cipher.
    ///
    /// # Errors
    /// Returns [`CrablockError::Crypto`] in either of these cases:
    /// - the stored nonce has the wrong length for the algorithm;
    /// - cipher initialization or encryption fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
        // `from_slice` below panics on a length mismatch, so validate first.
        self.check_nonce_len()?;
        let ciphertext = match self.algorithm {
            EncryptionAlgorithm::Aes256Gcm => {
                let cipher = Aes256Gcm::new_from_slice(&self.key.key)
                    .map_err(|e| CrablockError::Crypto(format!("AES key init failed: {e:?}")))?;
                let nonce = Nonce::<Aes256Gcm>::from_slice(&self.nonce);
                cipher
                    .encrypt(nonce, plaintext)
                    .map_err(|e| CrablockError::Crypto(format!("AES encryption failed: {e:?}")))?
            }
            EncryptionAlgorithm::ChaCha20Poly1305 => {
                use chacha20poly1305::aead::Aead as ChaChaAead;
                use chacha20poly1305::aead::KeyInit as ChaChaKeyInit;
                use chacha20poly1305::Nonce as ChaChaNonce;
                let cipher = ChaCha20Poly1305::new_from_slice(&self.key.key)
                    .map_err(|e| CrablockError::Crypto(format!("ChaCha key init failed: {e:?}")))?;
                let nonce = ChaChaNonce::from_slice(&self.nonce);
                cipher.encrypt(nonce, plaintext).map_err(|e| {
                    CrablockError::Crypto(format!("ChaCha encryption failed: {e:?}"))
                })?
            }
        };
        Ok(ciphertext)
    }

    /// The nonce that [`Encryptor::encrypt`] uses.
    pub fn nonce(&self) -> &[u8] {
        &self.nonce
    }

    /// The configured algorithm.
    pub fn algorithm(&self) -> EncryptionAlgorithm {
        self.algorithm
    }

    /// Returns an error, instead of letting `from_slice` panic, when the stored
    /// nonce length does not match the algorithm.
    fn check_nonce_len(&self) -> Result<()> {
        let expected = self.algorithm.nonce_size();
        if self.nonce.len() != expected {
            return Err(CrablockError::Crypto(format!(
                "Invalid nonce length: expected {expected} bytes, got {}",
                self.nonce.len()
            )));
        }
        Ok(())
    }
}
/// AEAD decryptor for ciphertext produced by [`Encryptor`] under the same key,
/// nonce and algorithm.
pub struct Decryptor {
    algorithm: EncryptionAlgorithm,
    key: EncryptionKey,
    nonce: Vec<u8>,
}

impl Decryptor {
    /// Creates a decryptor.
    ///
    /// `nonce` is not validated here; [`Decryptor::decrypt`] rejects a nonce
    /// whose length differs from [`EncryptionAlgorithm::nonce_size`].
    pub fn new(algorithm: EncryptionAlgorithm, key: EncryptionKey, nonce: Vec<u8>) -> Self {
        Self {
            algorithm,
            key,
            nonce,
        }
    }

    /// Authenticates and decrypts `ciphertext`, which is ciphertext plus tag as
    /// returned by [`Encryptor::encrypt`].
    ///
    /// # Errors
    /// - [`CrablockError::DecryptionFailed`] if the nonce has the wrong length
    ///   for the algorithm, or if authentication fails (for example a wrong key,
    ///   a wrong nonce, or modified ciphertext).
    /// - [`CrablockError::Crypto`] if cipher initialization fails.
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
        // The nonce typically comes from stored data; a malformed length must be
        // an error, not a panic inside `from_slice`.
        self.check_nonce_len()?;
        let plaintext = match self.algorithm {
            EncryptionAlgorithm::Aes256Gcm => {
                let cipher = Aes256Gcm::new_from_slice(&self.key.key)
                    .map_err(|e| CrablockError::Crypto(format!("AES key init failed: {e:?}")))?;
                let nonce = Nonce::<Aes256Gcm>::from_slice(&self.nonce);
                cipher.decrypt(nonce, ciphertext).map_err(|e| {
                    CrablockError::DecryptionFailed(format!("AES decryption failed: {e:?}"))
                })?
            }
            EncryptionAlgorithm::ChaCha20Poly1305 => {
                use chacha20poly1305::aead::Aead as ChaChaAead;
                use chacha20poly1305::aead::KeyInit as ChaChaKeyInit;
                use chacha20poly1305::Nonce as ChaChaNonce;
                let cipher = ChaCha20Poly1305::new_from_slice(&self.key.key)
                    .map_err(|e| CrablockError::Crypto(format!("ChaCha key init failed: {e:?}")))?;
                let nonce = ChaChaNonce::from_slice(&self.nonce);
                cipher.decrypt(nonce, ciphertext).map_err(|e| {
                    CrablockError::DecryptionFailed(format!("ChaCha decryption failed: {e:?}"))
                })?
            }
        };
        Ok(plaintext)
    }

    /// Returns an error, instead of letting `from_slice` panic, when the nonce
    /// length does not match the algorithm.
    fn check_nonce_len(&self) -> Result<()> {
        let expected = self.algorithm.nonce_size();
        if self.nonce.len() != expected {
            return Err(CrablockError::DecryptionFailed(format!(
                "Invalid nonce length: expected {expected} bytes, got {}",
                self.nonce.len()
            )));
        }
        Ok(())
    }
}
/// Hashes `data` with SHA-256 and returns the 32-byte digest encoded as a
/// lowercase, 64-character hexadecimal string.
pub fn compute_sha256(data: &[u8]) -> String {
    let digest = Sha256::digest(data);
    hex::encode(digest)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypts a fixed message with a fresh key, checks that the ciphertext is
    /// the plaintext length plus one tag, and checks that it decrypts back.
    fn assert_roundtrip(algorithm: EncryptionAlgorithm) {
        let key = EncryptionKey::generate_random();
        let plaintext = b"Hello, World!";
        let encryptor = Encryptor::new(algorithm, key.clone());
        let nonce = encryptor.nonce().to_vec();
        assert_eq!(nonce.len(), algorithm.nonce_size());
        let ciphertext = encryptor.encrypt(plaintext).unwrap();
        assert_eq!(ciphertext.len(), plaintext.len() + TAG_SIZE);
        let decryptor = Decryptor::new(algorithm, key, nonce);
        let decrypted = decryptor.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_aes_encryption_roundtrip() {
        assert_roundtrip(EncryptionAlgorithm::Aes256Gcm);
    }

    #[test]
    fn test_chacha_encryption_roundtrip() {
        assert_roundtrip(EncryptionAlgorithm::ChaCha20Poly1305);
    }

    #[test]
    fn test_wrong_key_fails() {
        for algorithm in [
            EncryptionAlgorithm::Aes256Gcm,
            EncryptionAlgorithm::ChaCha20Poly1305,
        ] {
            let key1 = EncryptionKey::generate_random();
            let key2 = EncryptionKey::generate_random();
            let encryptor = Encryptor::new(algorithm, key1);
            let nonce = encryptor.nonce().to_vec();
            let ciphertext = encryptor.encrypt(b"Hello, World!").unwrap();
            let decryptor = Decryptor::new(algorithm, key2, nonce);
            assert!(decryptor.decrypt(&ciphertext).is_err());
        }
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        for algorithm in [
            EncryptionAlgorithm::Aes256Gcm,
            EncryptionAlgorithm::ChaCha20Poly1305,
        ] {
            let key = EncryptionKey::generate_random();
            let encryptor = Encryptor::new(algorithm, key.clone());
            let nonce = encryptor.nonce().to_vec();
            let mut ciphertext = encryptor.encrypt(b"Hello, World!").unwrap();
            ciphertext[0] ^= 0x01;
            let decryptor = Decryptor::new(algorithm, key, nonce);
            assert!(decryptor.decrypt(&ciphertext).is_err());
        }
    }

    #[test]
    fn test_key_parsing() {
        use base64::Engine;
        let from_hex = EncryptionKey::from_hex(&"ab".repeat(KEY_SIZE)).unwrap();
        assert_eq!(from_hex.key, [0xab; KEY_SIZE]);
        let b64 = base64::engine::general_purpose::STANDARD.encode([0xab; KEY_SIZE]);
        let from_b64 = EncryptionKey::from_base64(&b64).unwrap();
        assert_eq!(from_b64.key, [0xab; KEY_SIZE]);
        assert!(EncryptionKey::from_hex("abcd").is_err());
        assert!(EncryptionKey::from_hex("zz").is_err());
        assert!(EncryptionKey::from_base64("not base64!").is_err());
    }

    #[test]
    fn test_algorithm_from_str() {
        for algorithm in [
            EncryptionAlgorithm::Aes256Gcm,
            EncryptionAlgorithm::ChaCha20Poly1305,
        ] {
            let parsed: EncryptionAlgorithm = algorithm.as_str().parse().unwrap();
            assert_eq!(parsed, algorithm);
        }
        let upper: EncryptionAlgorithm = "AES-256-GCM".parse().unwrap();
        assert_eq!(upper, EncryptionAlgorithm::Aes256Gcm);
        assert!("des".parse::<EncryptionAlgorithm>().is_err());
    }

    #[test]
    fn test_sha256() {
        // Known-answer vectors (FIPS 180-4 SHA-256).
        assert_eq!(
            compute_sha256(b"hello"),
            "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
        );
        assert_eq!(
            compute_sha256(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }
}