use std::collections::HashSet;
use std::iter;
use std::time::Duration;
use age_core::{
format::{FileKey, Stanza, FILE_KEY_BYTES},
primitives::{aead_decrypt, aead_encrypt},
secrecy::{ExposeSecret, SecretString},
};
use base64::{prelude::BASE64_STANDARD_NO_PAD, Engine};
use rand::{
distributions::{Alphanumeric, DistString},
rngs::OsRng,
RngCore,
};
use zeroize::Zeroize;
use crate::{
error::{DecryptError, EncryptError},
primitives::scrypt,
util::read::{base64_arg, decimal_digit_arg},
};
/// Stanza tag that identifies passphrase (scrypt) recipient stanzas; written by
/// `wrap_file_key` and matched by `unwrap_stanza`.
pub(super) const SCRYPT_RECIPIENT_TAG: &str = "scrypt";
/// Fixed label prepended to the random salt to form the salt input passed to scrypt
/// (presumably for domain separation of these derivations).
const SCRYPT_SALT_LABEL: &[u8] = b"age-encryption.org/v1/scrypt";
/// Target duration of a single scrypt derivation when calibrating the work factor.
const ONE_SECOND: Duration = Duration::from_secs(1);
/// Length in bytes of the random per-stanza salt.
const SALT_LEN: usize = 16;
/// Required stanza body length: the file key plus 16 bytes of AEAD overhead
/// (presumably the authentication tag — confirm against `aead_encrypt`).
const ENCRYPTED_FILE_KEY_BYTES: usize = FILE_KEY_BYTES + 16;
fn target_scrypt_work_factor() -> u8 {
let measure_duration = |log_n| {
#[cfg(not(all(target_arch = "wasm32", not(target_os = "wasi"))))]
{
use std::time::SystemTime;
let start = SystemTime::now();
scrypt(&[], log_n, "").expect("log_n < 64");
SystemTime::now().duration_since(start).ok()
}
#[cfg(all(target_arch = "wasm32", not(target_os = "wasi"), feature = "web-sys"))]
{
web_sys::window().and_then(|window| {
{ window.performance() }.map(|performance| {
let start = performance.now();
scrypt(&[], log_n, "").expect("log_n < 64");
Duration::from_secs_f64((performance.now() - start) / 1_000e0)
})
})
}
#[cfg(all(
target_arch = "wasm32",
not(target_os = "wasi"),
not(feature = "web-sys")
))]
{
None
}
};
let mut log_n = 10;
let mut duration: Option<Duration> = measure_duration(log_n);
while duration.map(|d| d.is_zero()).unwrap_or(false) {
log_n += 1;
duration = measure_duration(log_n);
}
duration
.map(|mut d| {
while d < ONE_SECOND && log_n < 63 {
log_n += 1;
d *= 2;
}
log_n
})
.unwrap_or({
18
})
}
/// A passphrase-based recipient.
///
/// The key that wraps the file key is derived from the passphrase with scrypt, at the work
/// factor stored in `log_n`.
pub struct Recipient {
    /// Passphrase the wrapping key is derived from.
    passphrase: SecretString,
    /// Base-2 logarithm of the scrypt work factor used when wrapping.
    log_n: u8,
}

impl Recipient {
    /// Creates a recipient for `passphrase`.
    ///
    /// The work factor is calibrated for the current machine by
    /// `target_scrypt_work_factor`.
    pub fn new(passphrase: SecretString) -> Self {
        let log_n = target_scrypt_work_factor();
        Recipient { passphrase, log_n }
    }

    /// Overrides the scrypt work factor (`log_n`) used when wrapping file keys.
    ///
    /// # Panics
    ///
    /// Panics if `log_n` is `0` or is `64` or greater.
    pub fn set_work_factor(&mut self, log_n: u8) {
        assert!(0 < log_n && log_n < 64);
        self.log_n = log_n;
    }
}
impl crate::Recipient for Recipient {
    /// Wraps `file_key` into a single `scrypt` stanza.
    ///
    /// Each call draws a fresh random salt. The wrapping key is scrypt(passphrase) at
    /// `self.log_n`, salted with `SCRYPT_SALT_LABEL || salt`. The stanza arguments are the
    /// unpadded base64 salt and the decimal `log_n`. The returned label set contains one
    /// random 32-character alphanumeric label. NOTE(review): presumably this keeps the
    /// stanza from being combined with other recipients' stanzas — confirm against the
    /// label-handling caller.
    fn wrap_file_key(
        &self,
        file_key: &FileKey,
    ) -> Result<(Vec<Stanza>, HashSet<String>), EncryptError> {
        let mut rng = OsRng;

        let mut salt = [0; SALT_LEN];
        rng.fill_bytes(&mut salt);

        // scrypt salt input = fixed label followed by the random salt.
        let mut inner_salt = [0u8; SCRYPT_SALT_LABEL.len() + SALT_LEN];
        let (prefix, suffix) = inner_salt.split_at_mut(SCRYPT_SALT_LABEL.len());
        prefix.copy_from_slice(SCRYPT_SALT_LABEL);
        suffix.copy_from_slice(&salt);

        let enc_key =
            scrypt(&inner_salt, self.log_n, self.passphrase.expose_secret()).expect("log_n < 64");
        let wrapped_key = aead_encrypt(&enc_key, file_key.expose_secret());

        let stanza = Stanza {
            tag: SCRYPT_RECIPIENT_TAG.to_owned(),
            args: vec![BASE64_STANDARD_NO_PAD.encode(salt), self.log_n.to_string()],
            body: wrapped_key,
        };

        let mut labels = HashSet::new();
        labels.insert(Alphanumeric.sample_string(&mut rng, 32));

        Ok((vec![stanza], labels))
    }
}
/// A passphrase-based identity that can unwrap `scrypt` stanzas.
///
/// To bound decryption cost, stanzas whose work factor exceeds `max_work_factor` are
/// rejected.
pub struct Identity {
    /// Passphrase the unwrapping key is derived from.
    passphrase: SecretString,
    /// Work factor calibrated for this machine; reported in `ExcessiveWork` errors.
    target_work_factor: u8,
    /// Largest `log_n` this identity is willing to run scrypt with.
    max_work_factor: u8,
}

impl Identity {
    /// Creates an identity for `passphrase`.
    ///
    /// The calibrated work factor comes from `target_scrypt_work_factor`. The default
    /// maximum is four steps above it, i.e. roughly 16 times the calibrated cost, since
    /// cost doubles per step.
    pub fn new(passphrase: SecretString) -> Self {
        let target_work_factor = target_scrypt_work_factor();
        Identity {
            passphrase,
            max_work_factor: target_work_factor + 4,
            target_work_factor,
        }
    }

    /// Sets the largest scrypt work factor (`log_n`) accepted when unwrapping.
    pub fn set_max_work_factor(&mut self, max_log_n: u8) {
        self.max_work_factor = max_log_n;
    }
}
impl crate::Identity for Identity {
    /// Tries to recover the file key from an `scrypt` stanza.
    ///
    /// Returns `None` if the stanza's tag is not `scrypt`. Otherwise it returns one of:
    /// - `Err(DecryptError::InvalidHeader)` if the stanza does not have exactly two
    ///   arguments, either argument fails to parse, or the body length is not
    ///   `ENCRYPTED_FILE_KEY_BYTES`;
    /// - `Err(DecryptError::ExcessiveWork { .. })` if the work factor exceeds
    ///   `max_work_factor` or `scrypt` returns an error;
    /// - the outcome of AEAD decryption of the body, with failures converted via
    ///   `DecryptError::from`.
    fn unwrap_stanza(&self, stanza: &Stanza) -> Option<Result<FileKey, DecryptError>> {
        if stanza.tag != SCRYPT_RECIPIENT_TAG {
            return None;
        }

        let excessive_work = |required| DecryptError::ExcessiveWork {
            required,
            target: self.target_work_factor,
        };

        // Exactly two arguments, [salt, log_n], both of which must parse.
        let parsed = match &stanza.args[..] {
            [salt, log_n] => base64_arg::<_, SALT_LEN, 18>(salt).zip(decimal_digit_arg(log_n)),
            _ => None,
        };
        let (salt, log_n) = match parsed {
            Some(pair) => pair,
            None => return Some(Err(DecryptError::InvalidHeader)),
        };
        if stanza.body.len() != ENCRYPTED_FILE_KEY_BYTES {
            return Some(Err(DecryptError::InvalidHeader));
        }

        // Refuse before running scrypt, so a hostile header cannot force unbounded work.
        if log_n > self.max_work_factor {
            return Some(Err(excessive_work(log_n)));
        }

        // scrypt salt input = fixed label followed by the stanza's salt.
        let mut inner_salt = [0u8; SCRYPT_SALT_LABEL.len() + SALT_LEN];
        let (prefix, suffix) = inner_salt.split_at_mut(SCRYPT_SALT_LABEL.len());
        prefix.copy_from_slice(SCRYPT_SALT_LABEL);
        suffix.copy_from_slice(&salt);

        let enc_key = match scrypt(&inner_salt, log_n, self.passphrase.expose_secret()) {
            Ok(key) => key,
            Err(_) => return Some(Err(excessive_work(log_n))),
        };

        // Copy the plaintext into the FileKey, then wipe the temporary buffer.
        let unwrapped =
            aead_decrypt(&enc_key, FILE_KEY_BYTES, &stanza.body).map(|mut plaintext| {
                FileKey::init_with_mut(|file_key| {
                    file_key.copy_from_slice(&plaintext);
                    plaintext.zeroize();
                })
            });
        Some(unwrapped.map_err(DecryptError::from))
    }
}