use base64;
use hmac::Hmac;
use pbkdf2::pbkdf2;
use rand::{rngs::OsRng, RngCore};
use sha2::Sha256;
use super::DataType;
use super::Error;
use super::Header;
use super::Result;
/// Returns a freshly allocated buffer of `length` random bytes.
///
/// The buffer starts zero-initialised and is then filled in place from `OsRng`.
pub fn generate_key(length: usize) -> Vec<u8> {
    let mut buffer = vec![0u8; length];
    OsRng.fill_bytes(buffer.as_mut_slice());
    buffer
}
/// Derives `length` bytes from `key` and `salt` using PBKDF2 with HMAC-SHA256.
///
/// # Arguments
/// * `key` - Input keying material, passed to PBKDF2 as the password.
/// * `salt` - Salt mixed into the derivation.
/// * `iterations` - PBKDF2 iteration count.
/// * `length` - Number of output bytes to produce.
///
/// # Returns
/// A new vector holding exactly `length` derived bytes.
pub fn derive_key(key: &[u8], salt: &[u8], iterations: usize, length: usize) -> Vec<u8> {
    let mut new_key = vec![0u8; length];
    // `key` and `salt` are already slices. Borrowing them again (`&key`) only worked through
    // auto-deref and triggers clippy's `needless_borrow`.
    pbkdf2::<Hmac<Sha256>>(key, salt, iterations, &mut new_key);
    new_key
}
/// Checks whether `data` starts with a header that parses for the requested `data_type`.
///
/// Only the first `Header::len()` bytes are inspected, and any bytes after them are ignored.
/// Returns `false` when `data` is shorter than a header or when `data_type` is
/// `DataType::None`. For `DataType::Key`, a header is accepted if it parses as either a
/// `PrivateKey` or a `PublicKey` header.
pub fn validate_header(data: &[u8], data_type: DataType) -> bool {
    use super::ciphertext::Ciphertext;
    use super::key::{PrivateKey, PublicKey};
    use super::password_hash::PasswordHash;
    use super::secret_sharing::Share;
    use std::convert::TryFrom;

    // `get` yields `None` when `data` is too short to hold a complete header.
    let header_bytes = match data.get(..Header::len()) {
        Some(bytes) => bytes,
        None => return false,
    };

    match data_type {
        DataType::None => false,
        DataType::Ciphertext => Header::<Ciphertext>::try_from(header_bytes).is_ok(),
        DataType::PasswordHash => Header::<PasswordHash>::try_from(header_bytes).is_ok(),
        DataType::Key => {
            Header::<PrivateKey>::try_from(header_bytes).is_ok()
                || Header::<PublicKey>::try_from(header_bytes).is_ok()
        }
        DataType::Share => Header::<Share>::try_from(header_bytes).is_ok(),
    }
}
pub fn base64_encode(data: &[u8]) -> String {
base64::encode(data)
}
/// Encodes `data` as base64 using the URL-safe alphabet, without `=` padding.
pub fn base64_encode_url(data: &[u8]) -> String {
    base64::encode_config(data, base64::URL_SAFE_NO_PAD)
}
pub fn base64_decode(data: &str) -> Result<Vec<u8>> {
match base64::decode(data) {
Ok(d) => Ok(d),
_ => Err(Error::InvalidData),
}
}
pub fn base64_decode_url(data: &str) -> Result<Vec<u8>> {
let config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
match base64::decode_config(data, config) {
Ok(d) => Ok(d),
_ => Err(Error::InvalidData),
}
}
#[test]
fn test_generate_key() {
    const KEY_LEN: usize = 32;
    let generated = generate_key(KEY_LEN);

    // The returned key must have exactly the requested length.
    assert_eq!(KEY_LEN, generated.len());
    // generate_key starts from a zeroed buffer, so an all-zero result suggests nothing was written.
    assert_ne!(vec![0u8; KEY_LEN], generated);
}
#[test]
fn test_derive_key() {
    let salt = b"salt";
    let key = b"key";
    let iterations = 100;
    let size = 32;

    let derived = derive_key(key, salt, iterations, size);
    assert_eq!(size, derived.len());
    assert_ne!(vec![0u8; size], derived);

    // Derivation must be deterministic for identical inputs.
    assert_eq!(derived, derive_key(key, salt, iterations, size));
    // A different salt must produce a different key.
    assert_ne!(derived, derive_key(key, b"pepper", iterations, size));

    // Known-answer test: PBKDF2-HMAC-SHA256(P = "password", S = "salt", c = 1, dkLen = 32).
    // Catches a wrong PRF or a wrong argument order, which the checks above cannot detect.
    let expected: [u8; 32] = [
        0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, 0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4, 0xf8,
        0x37, 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48, 0x08, 0x05, 0x98, 0x7c, 0xb7, 0x0b,
        0xe1, 0x7b,
    ];
    assert_eq!(&expected[..], &derive_key(b"password", b"salt", 1, 32)[..]);
}
#[test]
fn test_validate_header() {
    use base64::decode;

    // Each vector decodes to four two-byte fields (low byte first):
    // signature (0D 0C), data type, subtype, version.
    let valid_ciphertext = decode("DQwCAAAAAQA=").unwrap();
    let valid_password_hash = decode("DQwDAAAAAQA=").unwrap();
    let valid_share = decode("DQwEAAAAAQA=").unwrap();
    // Key headers: subtype 1 for the private key. Subtype 2 is expected to be the public-key
    // subtype. The vectors must differ so both key paths are actually exercised.
    let valid_private_key = decode("DQwBAAEAAQA=").unwrap();
    let valid_public_key = decode("DQwBAAIAAQA=").unwrap();

    assert!(validate_header(&valid_ciphertext, DataType::Ciphertext));
    assert!(validate_header(
        &valid_password_hash,
        DataType::PasswordHash
    ));
    assert!(validate_header(&valid_share, DataType::Share));
    assert!(validate_header(&valid_private_key, DataType::Key));
    assert!(validate_header(&valid_public_key, DataType::Key));

    // A valid header of the wrong type, or the `None` type, is rejected.
    assert!(!validate_header(&valid_ciphertext, DataType::PasswordHash));
    assert!(!validate_header(&valid_ciphertext, DataType::None));

    // Bytes after the header are not inspected.
    let mut with_payload = valid_ciphertext.clone();
    with_payload.extend_from_slice(&[1, 2, 3]);
    assert!(validate_header(&with_payload, DataType::Ciphertext));

    let invalid_signature = decode("DAwBAAEAAQA=").unwrap();
    let invalid_type = decode("DQwIAAEAAQA=").unwrap();
    let invalid_subtype = decode("DQwBAAgAAQA=").unwrap();
    let invalid_version = decode("DQwBAAEACAA=").unwrap();
    assert!(!validate_header(&invalid_signature, DataType::Key));
    assert!(!validate_header(&invalid_type, DataType::Key));
    assert!(!validate_header(&invalid_subtype, DataType::Key));
    assert!(!validate_header(&invalid_version, DataType::Key));

    // One byte short of a full header.
    let not_long_enough = decode("DQwBAAEAAQ==").unwrap();
    assert!(!validate_header(&not_long_enough, DataType::Key));
}