use std::fmt::{Display, Formatter};
use std::io::{Cursor, Write};
use aes_gcm::aead::{Aead, Nonce, OsRng};
use aes_gcm::aes::Aes128;
use aes_gcm::{AeadCore, Aes128Gcm, AesGcm, Key, KeyInit};
use anyhow::{bail, ensure, Result};
use clap::ValueEnum;
use deadpool_postgres::tokio_postgres::types::{FromSql, ToSql};
use md5::digest::consts::U12;
use rc4::{Rc4, StreamCipher};
use xor_utils::Xor;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Encryption algorithm selector.
///
/// Usable as a CLI value (`clap::ValueEnum`) and as the Postgres enum type
/// `encryptionkey_algorithm`, whose labels are the lowercased variant names
/// (`aes128`, `rc4`, `xor`).
#[derive(Debug, Copy, Clone, Eq, PartialEq, ValueEnum, Hash, ToSql, FromSql)]
#[postgres(name = "encryptionkey_algorithm", rename_all = "lowercase")]
pub enum EncryptionOption {
    // AES-128 in GCM mode (`Aes128Gcm` in `FileEncryption`); requires a nonce.
    // Plain `//` comments on purpose: `///` on ValueEnum variants would become CLI help text.
    AES128,
    // RC4 stream cipher.
    RC4,
    // XOR of the data with the key bytes (via `xor_utils`).
    Xor,
}
impl TryFrom<&str> for EncryptionOption {
type Error = anyhow::Error;
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
match value {
"xor" => Ok(EncryptionOption::Xor),
"rc4" => Ok(EncryptionOption::RC4),
"aes128" => Ok(EncryptionOption::AES128),
_ => Err(anyhow::Error::msg(format!(
"Invalid encryption algorithm {value}"
))),
}
}
}
impl From<EncryptionOption> for FileEncryption {
fn from(option: EncryptionOption) -> Self {
let random_bytes = uuid::Uuid::new_v4().into_bytes().to_vec();
match option {
EncryptionOption::AES128 => FileEncryption::AES128(random_bytes),
EncryptionOption::RC4 => FileEncryption::RC4(random_bytes),
EncryptionOption::Xor => FileEncryption::Xor(random_bytes),
}
}
}
impl Display for EncryptionOption {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
EncryptionOption::Xor => write!(f, "Xor"),
EncryptionOption::RC4 => write!(f, "RC4"),
EncryptionOption::AES128 => write!(f, "AES-128"),
}
}
}
/// An encryption algorithm together with its key bytes.
///
/// The key bytes are zeroized when the value is dropped (`ZeroizeOnDrop`).
// NOTE(review): there is no `Debug`/`Clone` derive, presumably on purpose so key bytes are
// never printed or duplicated; confirm before adding either. The derived `PartialEq` compares
// key bytes with ordinary (not constant-time) `Vec` equality.
#[derive(Zeroize, ZeroizeOnDrop, Eq, PartialEq, Hash)]
pub enum FileEncryption {
    /// AES-128-GCM key; expected to be 16 bytes (see `FileEncryption::new`).
    AES128(Vec<u8>),
    /// RC4 key; expected to be 16 bytes (see `FileEncryption::new`).
    RC4(Vec<u8>),
    /// XOR key; expected to be 16 bytes (see `FileEncryption::new`).
    Xor(Vec<u8>),
}
impl FileEncryption {
pub fn new(option: EncryptionOption, bytes: Vec<u8>) -> Result<Self> {
ensure!(bytes.len() == 16);
match option {
EncryptionOption::AES128 => Ok(FileEncryption::AES128(bytes)),
EncryptionOption::RC4 => Ok(FileEncryption::RC4(bytes)),
EncryptionOption::Xor => Ok(FileEncryption::Xor(bytes)),
}
}
#[must_use]
pub fn name(&self) -> &'static str {
match self {
FileEncryption::AES128(_) => "aes128",
FileEncryption::RC4(_) => "rc4",
FileEncryption::Xor(_) => "xor",
}
}
#[must_use]
pub fn key_type(&self) -> EncryptionOption {
match self {
FileEncryption::AES128(_) => EncryptionOption::AES128,
FileEncryption::RC4(_) => EncryptionOption::RC4,
FileEncryption::Xor(_) => EncryptionOption::Xor,
}
}
#[must_use]
pub fn key(&self) -> &[u8] {
match self {
FileEncryption::AES128(key) | FileEncryption::RC4(key) | FileEncryption::Xor(key) => {
key.as_ref()
}
}
}
pub fn decrypt(&self, data: &[u8], nonce: Option<Vec<u8>>) -> Result<Vec<u8>> {
match self {
FileEncryption::AES128(key) => {
if let Some(nonce) = nonce {
ensure!(nonce.len() == 12, "AES nonce but be 12 bytes");
let nonce = Nonce::<AesGcm<Aes128, U12>>::from_slice(&nonce);
let key = Key::<Aes128Gcm>::from_slice(key);
let cipher = Aes128Gcm::new(key);
let decrypted = cipher.decrypt(nonce, data)?;
Ok(decrypted)
} else {
bail!("Nonce required for AES");
}
}
FileEncryption::RC4(key) => {
use rc4::KeyInit;
let mut key = Rc4::new_from_slice(key)?;
let mut output = vec![0u8; data.len()];
key.apply_keystream_b2b(data, &mut output);
Ok(output)
}
FileEncryption::Xor(key) => {
let mut reader = Cursor::new(data.to_vec());
let result = reader.by_ref().xor(key);
Ok(result)
}
}
}
pub fn encrypt(&self, data: &[u8], nonce: Option<Vec<u8>>) -> Result<Vec<u8>> {
match self {
FileEncryption::AES128(key) => {
if let Some(nonce) = nonce {
let nonce = Nonce::<AesGcm<Aes128, U12>>::from_slice(&nonce);
let key = Key::<Aes128Gcm>::from_slice(key);
let cipher = Aes128Gcm::new(key);
let encrypted = cipher.encrypt(nonce, data)?;
Ok(encrypted)
} else {
bail!("Nonce required for AES");
}
}
FileEncryption::RC4(key) => {
use rc4::KeyInit;
let mut key = Rc4::new_from_slice(key)?;
let mut output = vec![0u8; data.len()];
key.apply_keystream_b2b(data, &mut output);
Ok(output)
}
FileEncryption::Xor(key) => {
let mut reader = Cursor::new(data.to_vec());
let result = reader.by_ref().xor(key);
Ok(result)
}
}
}
pub fn nonce(&self) -> Option<Vec<u8>> {
match self {
FileEncryption::AES128(_) => {
let nonce = Aes128Gcm::generate_nonce(&mut OsRng);
Some(nonce.to_vec())
}
_ => None,
}
}
}
impl Display for FileEncryption {
    /// Writes only the lowercase algorithm name (as returned by `name()`).
    ///
    /// Key bytes are never included. Width and fill flags from the caller's format spec
    /// are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let algorithm = self.name();
        f.write_str(algorithm)
    }
}
#[cfg(test)]
mod tests {
    use super::{EncryptionOption, FileEncryption};
    use malwaredb_types::utils::EntropyCalc;
    use std::time::Instant;
    use rstest::rstest;

    // Round-trip test, run once per algorithm through the rstest cases below.
    // It encrypts a sample executable, checks that the ciphertext differs from the input and
    // has higher entropy, then decrypts and requires the original bytes back.
    #[rstest]
    #[case::rc4(EncryptionOption::RC4)]
    #[case::xor(EncryptionOption::Xor)]
    #[case::aes128(EncryptionOption::AES128)]
    #[test]
    fn enc_dec(#[case] option: EncryptionOption) {
        // Sample .NET PE32 executable from the types crate's test data; the path is relative
        // to this source file.
        const BYTES: &[u8] = include_bytes!("../../types/testdata/exe/pe32_dotnet.exe");
        let original_entropy = BYTES.entropy();
        // Fresh random key for the chosen algorithm.
        let encryptor = FileEncryption::from(option);
        // The measured duration covers nonce generation, encryption, the entropy checks
        // and decryption.
        let start = Instant::now();
        // Some(nonce) for AES-128, None otherwise; the same value is passed to decrypt below.
        let nonce = encryptor.nonce();
        let encrypted = encryptor.encrypt(BYTES, nonce.clone()).unwrap();
        assert_ne!(BYTES, encrypted);
        let encrypted_entropy = encrypted.entropy();
        assert!(encrypted_entropy > original_entropy, "{option}: Encrypted entropy {encrypted_entropy} should be higher than the original entropy {original_entropy}");
        // Xor is exempt from the absolute threshold; presumably a short repeating key keeps
        // too much of the plaintext's byte structure to reach it. Confirm.
        // NOTE(review): 7.0 presumably refers to Shannon entropy in bits per byte (max 8.0);
        // confirm against `EntropyCalc`.
        if option != EncryptionOption::Xor {
            assert!(
                encrypted_entropy > 7.0,
                "{option}: Entropy was {encrypted_entropy}, expected >7"
            );
        }
        let decrypted = encryptor.decrypt(&encrypted, nonce).unwrap();
        let duration = start.elapsed();
        println!(
            "{option} Time elapsed: {duration:?}, entropy increase: {:+.4}",
            encrypted_entropy - original_entropy
        );
        assert_eq!(BYTES, decrypted);
    }
}