use crate::args::CodecType;
use aes::cipher::{
generic_array::{typenum::U16, GenericArray},
BlockDecrypt, BlockEncrypt, KeyInit,
};
use aes::Aes256;
use pbkdf2::pbkdf2_hmac_array;
use sha3::Keccak256;
use std::str;
/// Fixed salt passed to PBKDF2 when deriving the key from the password.
const SALT: &[u8; 19] = b"salt-encrypt-rs-cli";
/// Round count passed to `pbkdf2_hmac_array` during key derivation.
const PBKDF2_N: u32 = 50_000;
/// Number of bytes handed to the cipher per block operation.
const BLOCK_SIZE: usize = 16;
/// Length in bytes of the derived key.
const KEY_LEN: usize = 32;
/// Password-keyed block cipher wrapper around an [`Aes256`] instance.
pub struct Cipher {
    /// The keyed AES-256 instance used for every block operation.
    cipher: Aes256,
}
impl Cipher {
    /// Builds a [`Cipher`] whose AES-256 key is derived from `password`.
    ///
    /// The same password always yields the same key, because the derivation
    /// uses the fixed [`SALT`] and [`PBKDF2_N`] constants.
    pub fn new(password: &str) -> Cipher {
        let key = Self::init_privkey(password);
        Cipher {
            cipher: Self::init_cipher(&key),
        }
    }

    /// Derives a [`KEY_LEN`]-byte key from `password` with PBKDF2-HMAC over
    /// Keccak-256, using [`SALT`] as the salt and [`PBKDF2_N`] rounds.
    fn init_privkey(password: &str) -> [u8; KEY_LEN] {
        pbkdf2_hmac_array::<Keccak256, KEY_LEN>(password.as_bytes(), SALT, PBKDF2_N)
    }

    /// Creates the AES-256 instance for the given key bytes.
    ///
    /// # Panics
    ///
    /// Panics if the cipher rejects the key length. The message is built with
    /// `panic!` so that the expected size is actually interpolated; a plain
    /// `expect("...{KEY_LEN}...")` would print the braces verbatim.
    fn init_cipher(bz: &[u8; KEY_LEN]) -> Aes256 {
        Aes256::new_from_slice(bz)
            .unwrap_or_else(|_| panic!("key bytes should be {}b size", KEY_LEN))
    }
}
impl Cipher {
pub fn apply_codec(&self, mut content: Vec<u8>, codec: &CodecType) -> Vec<u8> {
let mut codec_bz: Vec<u8> = Vec::new();
let mut block: [u8; BLOCK_SIZE];
let mut apply_codec_block: GenericArray<u8, U16>;
while content.len() >= BLOCK_SIZE {
block = content[..BLOCK_SIZE].try_into().unwrap();
match codec {
CodecType::Encrypt => apply_codec_block = self.enc_block(block),
CodecType::Decrypt => apply_codec_block = self.dec_block(block),
}
codec_bz.extend_from_slice(apply_codec_block.as_mut_slice());
content = content[BLOCK_SIZE..].to_vec();
}
if !content.is_empty() {
if codec == &CodecType::Decrypt {
eprintln!("warning: invalid encrypted data found, data is likely corrupted");
}
let vec: Vec<u8> = content[..].to_vec();
block = Self::padded_block(vec);
let mut enc_block = self.enc_block(block);
codec_bz.extend_from_slice(enc_block.as_mut_slice());
}
if codec == &CodecType::Decrypt {
Self::strip_trailing_null(&mut codec_bz);
}
codec_bz
}
fn enc_block(&self, block: [u8; BLOCK_SIZE]) -> GenericArray<u8, U16> {
self.apply_codec_block(block, CodecType::Encrypt)
}
fn dec_block(&self, block: [u8; BLOCK_SIZE]) -> GenericArray<u8, U16> {
self.apply_codec_block(block, CodecType::Decrypt)
}
fn apply_codec_block(&self, bz: [u8; BLOCK_SIZE], codec: CodecType) -> GenericArray<u8, U16> {
let mut block = GenericArray::from(bz);
match codec {
CodecType::Encrypt => self.cipher.encrypt_block(&mut block),
CodecType::Decrypt => self.cipher.decrypt_block(&mut block),
}
block
}
fn padded_block(mut contents: Vec<u8>) -> [u8; BLOCK_SIZE] {
let required = BLOCK_SIZE - contents.len();
let mut padding = vec![0; required];
contents.append(&mut padding);
Self::vec_to_block(contents)
}
fn vec_to_block(vec: Vec<u8>) -> [u8; BLOCK_SIZE] {
vec.try_into()
.expect("expect array to be same length as block")
}
fn strip_trailing_null(content: &mut Vec<u8>) {
while !content.is_empty() && content[content.len() - 1] == 0u8 {
content.pop();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a cipher from a password must not panic.
    #[test]
    fn test_new() {
        let _ = Cipher::new("password");
    }

    /// The derived key has the configured length.
    #[test]
    fn test_pbkdf2() {
        let key = Cipher::init_privkey("password");
        assert_eq!(key.len(), KEY_LEN);
    }

    /// A derived key is accepted by the cipher constructor.
    #[test]
    fn test_codec() {
        let _ = Cipher::init_cipher(&Cipher::init_privkey("password"));
    }

    /// Short input is padded out to a full block.
    #[test]
    fn test_padded_block() {
        let short = vec![0u8; BLOCK_SIZE - 10];
        assert_eq!(Cipher::padded_block(short).len(), BLOCK_SIZE);
    }

    /// Input longer than a block cannot be padded.
    #[test]
    #[should_panic]
    fn test_padded_block_panic() {
        let too_long = vec![0u8; BLOCK_SIZE + 1];
        Cipher::padded_block(too_long);
    }

    /// The conversion keeps every byte in its original position.
    #[test]
    fn test_vec_to_block() {
        let bytes: Vec<u8> = (0..BLOCK_SIZE as u8).collect();
        let block = Cipher::vec_to_block(bytes.clone());
        assert_eq!(&block[..], &bytes[..]);
    }

    /// A vector of the wrong length cannot become a block.
    #[test]
    #[should_panic]
    fn test_vec_to_block_panic() {
        let too_long = vec![0u8; BLOCK_SIZE + 1];
        Cipher::vec_to_block(too_long);
    }

    /// Encrypting then decrypting one block restores the plaintext.
    #[test]
    fn test_block_cipher() {
        let cipher = Cipher::new("password");
        let plaintext = [42u8; BLOCK_SIZE];

        let encrypted = cipher.enc_block(plaintext);
        assert_ne!(&plaintext[..], encrypted.as_slice());

        let mut encrypted_bytes = [0u8; BLOCK_SIZE];
        encrypted_bytes.copy_from_slice(encrypted.as_slice());

        let decrypted = cipher.dec_block(encrypted_bytes);
        assert_eq!(&plaintext[..], decrypted.as_slice());
    }

    #[test]
    fn test_content_cipher_under_block() {
        assert_round_trip("hello, world!");
    }

    #[test]
    fn test_content_cipher_over_block() {
        let text = "longer text, with multiple lines and\n\
                    many bytes, to span multiple blocks\n\
                    worth of encoding";
        assert_round_trip(text);
    }

    /// Encrypts `content`, checks the ciphertext differs from the input, then
    /// decrypts it and checks the original bytes come back.
    fn assert_round_trip(content: &str) {
        let cipher = Cipher::new("password");
        let original = content.as_bytes().to_vec();

        let encrypted = cipher.apply_codec(original.clone(), &CodecType::Encrypt);
        assert_ne!(original, encrypted);

        let decrypted = cipher.apply_codec(encrypted, &CodecType::Decrypt);
        assert_eq!(original, decrypted);
    }

    /// Different passwords produce different ciphertext for the same input.
    #[test]
    fn test_content_cipher_diff_passwords() {
        let plaintext = b"foo bar baz".to_vec();

        let with_foo = Cipher::new("foo").apply_codec(plaintext.clone(), &CodecType::Encrypt);
        let with_bar = Cipher::new("bar").apply_codec(plaintext.clone(), &CodecType::Encrypt);

        assert_ne!(with_foo, with_bar);
    }
}