use base64::{engine::general_purpose, Engine as _};
use hmac::Hmac;
use pbkdf2::pbkdf2;
use rand::{rngs::OsRng, RngCore};
use sha2::Sha256;
use subtle::ConstantTimeEq as _;
use super::Argon2Parameters;
use super::DataType;
use super::Error;
use super::Header;
use super::Result;
pub fn generate_key(length: usize) -> Vec<u8> {
let mut key = vec![0u8; length];
OsRng.fill_bytes(&mut key);
key
}
pub fn derive_key_pbkdf2(key: &[u8], salt: &[u8], iterations: u32, length: usize) -> Vec<u8> {
let mut new_key = vec![0u8; length];
let _ = pbkdf2::<Hmac<Sha256>>(key, salt, iterations, &mut new_key);
new_key
}
pub fn derive_key_argon2(key: &[u8], parameters: &Argon2Parameters) -> Result<Vec<u8>> {
parameters.compute(key)
}
pub fn validate_header(data: &[u8], data_type: DataType) -> bool {
use super::ciphertext::Ciphertext;
use super::key::{PrivateKey, PublicKey};
use super::password_hash::PasswordHash;
use super::secret_sharing::Share;
use super::signature::Signature;
use super::signing_key::{SigningKeyPair, SigningPublicKey};
if data.len() < Header::len() {
return false;
}
match data_type {
DataType::None => false,
DataType::Ciphertext => Header::<Ciphertext>::try_from(&data[0..Header::len()]).is_ok(),
DataType::PasswordHash => Header::<PasswordHash>::try_from(&data[0..Header::len()]).is_ok(),
DataType::Key => {
Header::<PrivateKey>::try_from(&data[0..Header::len()]).is_ok()
|| Header::<PublicKey>::try_from(&data[0..Header::len()]).is_ok()
}
DataType::SigningKey => {
Header::<SigningKeyPair>::try_from(&data[0..Header::len()]).is_ok()
|| Header::<SigningPublicKey>::try_from(&data[0..Header::len()]).is_ok()
}
DataType::Share => Header::<Share>::try_from(&data[0..Header::len()]).is_ok(),
DataType::Signature => Header::<Signature>::try_from(&data[0..Header::len()]).is_ok(),
}
}
pub fn scrypt_simple(password: &[u8], salt: &[u8], log_n: u8, r: u32, p: u32) -> String {
use byteorder::{ByteOrder, LittleEndian};
use general_purpose::STANDARD;
let params = scrypt::Params::new(log_n, r, p, 32).expect("params should be valid");
let mut dk = [0u8; 32];
scrypt::scrypt(password, salt, ¶ms, &mut dk)
.expect("32 bytes always satisfy output length requirements");
let mut result = String::with_capacity(128);
result.push_str("$rscrypt$");
if r < 256 && p < 256 {
result.push_str("0$");
let mut tmp = [0u8; 3];
tmp[0] = log_n;
tmp[1] = r as u8;
tmp[2] = p as u8;
result.push_str(&STANDARD.encode(tmp));
} else {
result.push_str("1$");
let mut tmp = [0u8; 9];
tmp[0] = log_n;
LittleEndian::write_u32(&mut tmp[1..5], r);
LittleEndian::write_u32(&mut tmp[5..9], p);
result.push_str(&STANDARD.encode(tmp));
}
result.push('$');
result.push_str(&STANDARD.encode(salt));
result.push('$');
result.push_str(&STANDARD.encode(dk));
result.push('$');
result
}
pub fn base64_encode(data: &[u8]) -> String {
general_purpose::STANDARD.encode(data)
}
pub fn base64_encode_url(data: &[u8]) -> String {
general_purpose::URL_SAFE_NO_PAD.encode(data)
}
pub fn base64_decode(data: &str) -> Result<Vec<u8>> {
match general_purpose::STANDARD.decode(data) {
Ok(d) => Ok(d),
_ => Err(Error::InvalidData),
}
}
pub fn base64_decode_url(data: &str) -> Result<Vec<u8>> {
match general_purpose::URL_SAFE_NO_PAD.decode(data) {
Ok(d) => Ok(d),
_ => Err(Error::InvalidData),
}
}
pub fn constant_time_equals(x: &[u8], y: &[u8]) -> bool {
x.ct_eq(y).into()
}
#[test]
fn test_constant_time_equals() {
let x: [u8; 3] = [0, 1, 2];
let y: [u8; 3] = [4, 5, 6];
let z: [u8; 4] = [0, 1, 2, 3];
assert!(constant_time_equals(&x, &x));
assert!(!constant_time_equals(&x, &y));
assert!(!constant_time_equals(&x, &z));
assert!(!constant_time_equals(&y, &x));
assert!(constant_time_equals(&y, &y));
assert!(!constant_time_equals(&y, &z));
assert!(!constant_time_equals(&z, &x));
assert!(!constant_time_equals(&z, &y));
assert!(constant_time_equals(&z, &z));
}
#[test]
fn test_generate_key() {
let size = 32;
let key = generate_key(size);
assert_eq!(size, key.len());
assert_ne!(vec![0u8; size], key);
}
#[test]
fn test_derive_key_pbkdf2() {
let salt = b"salt";
let key = b"key";
let iterations = 100;
let size = 32;
let derived = derive_key_pbkdf2(key, salt, iterations, size);
assert_eq!(size, derived.len());
assert_ne!(vec![0u8; size], derived);
}
#[test]
fn test_validate_header() {
use general_purpose::STANDARD;
let valid_ciphertext = STANDARD.decode("DQwCAAAAAQA=").unwrap();
let valid_password_hash = STANDARD.decode("DQwDAAAAAQA=").unwrap();
let valid_share = STANDARD.decode("DQwEAAAAAQA=").unwrap();
let valid_private_key = STANDARD.decode("DQwBAAEAAQA=").unwrap();
let valid_public_key = STANDARD.decode("DQwBAAEAAQA=").unwrap();
assert!(validate_header(&valid_ciphertext, DataType::Ciphertext));
assert!(validate_header(
&valid_password_hash,
DataType::PasswordHash
));
assert!(validate_header(&valid_share, DataType::Share));
assert!(validate_header(&valid_private_key, DataType::Key));
assert!(validate_header(&valid_public_key, DataType::Key));
assert!(!validate_header(&valid_ciphertext, DataType::PasswordHash));
let invalid_signature = STANDARD.decode("DAwBAAEAAQA=").unwrap();
let invalid_type = STANDARD.decode("DQwIAAEAAQA=").unwrap();
let invalid_subtype = STANDARD.decode("DQwBAAgAAQA=").unwrap();
let invalid_version = STANDARD.decode("DQwBAAEACAA=").unwrap();
assert!(!validate_header(&invalid_signature, DataType::Key));
assert!(!validate_header(&invalid_type, DataType::Key));
assert!(!validate_header(&invalid_subtype, DataType::Key));
assert!(!validate_header(&invalid_version, DataType::Key));
let not_long_enough = STANDARD.decode("DQwBAAEAAQ==").unwrap();
assert!(!validate_header(¬_long_enough, DataType::Key));
}