use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::Aes256Gcm;
use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine;
use scrypt::scrypt;
use serde::Deserialize;
use zeroize::{Zeroize, Zeroizing};
use crate::bulk::BulkError;
// Sanity caps on slot-supplied scrypt parameters: a slot exceeding any of them is rejected
// before key derivation is attempted.
/// Largest accepted scrypt cost parameter `n` (2^22).
const SCRYPT_MAX_N: u64 = 1 << 22;
/// Largest accepted scrypt block-size parameter `r`.
const SCRYPT_MAX_R: u32 = 64;
/// Largest accepted scrypt parallelism parameter `p`.
const SCRYPT_MAX_P: u32 = 8;
/// Cap (256 MiB) on `128 * n * r`, the memory estimate checked before running scrypt.
const SCRYPT_MAX_MEM_BYTES: u64 = 256 * 1024 * 1024;
/// Maximum number of password slots `decrypt_aegis` tries; each attempt runs scrypt.
const MAX_PASSWORD_SLOTS: usize = 8;
/// Top-level object of an Aegis encrypted vault export, as deserialized by `decrypt_aegis`.
#[derive(Deserialize)]
struct Root {
/// Key slots plus the AES-GCM parameters for `db`.
header: Header,
/// Base64-encoded AES-256-GCM ciphertext of the vault contents; its tag is stored separately
/// in `header.params.tag`.
db: String,
}
/// Vault header: the key slots and the AEAD parameters used to decrypt `Root::db`.
#[derive(Deserialize)]
struct Header {
/// All key slots; only password slots (`type == 1`) are used by `decrypt_aegis`.
slots: Vec<Slot>,
/// Hex-encoded nonce and tag for the `db` ciphertext.
params: AeadParams,
}
/// One key slot: the vault master key, AES-GCM encrypted under a key derived for this slot.
///
/// The scrypt fields (`n`, `r`, `p`, `salt`) are optional in the schema; the password-slot
/// unlock path rejects a slot that lacks any of them.
#[derive(Deserialize)]
struct Slot {
/// Slot type (JSON key `type`); `decrypt_aegis` treats `1` as a password slot and skips all
/// other values.
#[serde(rename = "type")]
typ: u32,
/// Hex-encoded AES-GCM ciphertext of the master key; its tag is held in `key_params.tag`.
key: String,
/// Hex-encoded nonce and tag for `key`.
key_params: AeadParams,
/// scrypt cost parameter; must be a power of two, at least 2 and at most `SCRYPT_MAX_N`.
n: Option<u64>,
/// scrypt block size; at most `SCRYPT_MAX_R`.
r: Option<u32>,
/// scrypt parallelism; at most `SCRYPT_MAX_P`.
p: Option<u32>,
/// Hex-encoded scrypt salt.
salt: Option<String>,
}
/// Hex-encoded AES-GCM parameters for a ciphertext whose tag is stored separately.
#[derive(Deserialize)]
struct AeadParams {
/// Hex-encoded nonce; must decode to 12 bytes.
nonce: String,
/// Hex-encoded authentication tag; must decode to 16 bytes.
tag: String,
}
/// Decodes the hex string `s` into bytes.
///
/// Any decode failure is reported as `BulkError::UnsupportedFormat(label)`, so `label` should
/// name the field being decoded; the underlying decoder's error detail is discarded.
fn hex_decode(s: &str, label: &'static str) -> Result<Vec<u8>, BulkError> {
    match keyroost_proto::codec::hex_decode(s) {
        Ok(bytes) => Ok(bytes),
        Err(_) => Err(BulkError::UnsupportedFormat(label)),
    }
}
/// Decrypts an Aegis encrypted vault export with a password.
///
/// Password slots (`type == 1`, at most [`MAX_PASSWORD_SLOTS`] of them) are tried in order. The
/// first slot whose scrypt-derived key authenticates its wrapped key yields the 32-byte master
/// key, which then decrypts the base64 `db` payload. The plaintext is returned wrapped as
/// `{"db":<plaintext>}` in a string that is zeroed on drop.
///
/// # Errors
///
/// - `json` does not deserialize into the expected vault shape (propagated through
///   `BulkError`'s conversion from `serde_json::Error`).
/// - `UnsupportedFormat` if the vault has no password slot, if no tried slot unlocks (the error
///   of the last slot tried is returned), or if the `db` payload cannot be decoded,
///   authenticated or read as UTF-8 once a slot has unlocked.
pub fn decrypt_aegis(json: &str, password: &[u8]) -> Result<Zeroizing<String>, BulkError> {
    let root: Root = serde_json::from_str(json)?;
    let mut password_slots = root
        .header
        .slots
        .iter()
        .filter(|slot| slot.typ == 1)
        .peekable();
    if password_slots.peek().is_none() {
        return Err(BulkError::UnsupportedFormat(
            "Aegis vault has no password slot (biometric/keystore not supported)",
        ));
    }
    let mut last_err: Option<BulkError> = None;
    for slot in password_slots.take(MAX_PASSWORD_SLOTS) {
        match unwrap_master_key(slot, password) {
            // Every slot feeds the same `db` ciphertext and params. Once a slot authenticates,
            // a `db` failure is therefore returned as-is. Trying further slots would rerun
            // scrypt and could mask the real error behind a later "wrong password".
            Ok(master_key) => return decrypt_db(&master_key, &root.header.params, &root.db),
            Err(e) => last_err = Some(e),
        }
    }
    Err(last_err.unwrap_or(BulkError::UnsupportedFormat("Aegis decrypt failed")))
}

/// Derives `slot`'s key-encryption key from `password` with scrypt and uses it to unwrap the
/// slot's 32-byte master key.
///
/// All hex fields are decoded and the scrypt parameters validated before scrypt runs, so a
/// malformed slot is rejected without paying for key derivation.
fn unwrap_master_key(slot: &Slot, password: &[u8]) -> Result<Zeroizing<Vec<u8>>, BulkError> {
    let salt_hex = slot
        .salt
        .as_deref()
        .ok_or(BulkError::UnsupportedFormat("slot missing salt"))?;
    let salt = hex_decode(salt_hex, "slot salt")?;
    let n = slot.n.ok_or(BulkError::UnsupportedFormat("slot missing n"))?;
    let r = slot.r.ok_or(BulkError::UnsupportedFormat("slot missing r"))?;
    let p = slot.p.ok_or(BulkError::UnsupportedFormat("slot missing p"))?;
    let params = scrypt_params(n, r, p)?;
    let slot_nonce = hex_decode(&slot.key_params.nonce, "slot nonce")?;
    let slot_tag = hex_decode(&slot.key_params.tag, "slot tag")?;
    let slot_ct = hex_decode(&slot.key, "slot key ciphertext")?;

    let mut kek = Zeroizing::new([0u8; 32]);
    scrypt(password, &salt, &params, kek.as_mut())
        .map_err(|_| BulkError::UnsupportedFormat("scrypt failed"))?;
    let master_key = Zeroizing::new(
        gcm_decrypt(kek.as_ref(), &slot_nonce, &slot_ct, &slot_tag)
            .map_err(|()| BulkError::UnsupportedFormat("wrong password (slot did not decrypt)"))?,
    );
    if master_key.len() != 32 {
        return Err(BulkError::UnsupportedFormat(
            "decrypted master key is not 32 bytes",
        ));
    }
    Ok(master_key)
}

/// Validates slot-supplied scrypt parameters against the sanity caps and builds
/// `scrypt::Params` for a 32-byte output.
fn scrypt_params(n: u64, r: u32, p: u32) -> Result<scrypt::Params, BulkError> {
    if n < 2 || !n.is_power_of_two() {
        return Err(BulkError::UnsupportedFormat(
            "slot n is not a valid power of 2",
        ));
    }
    // `||` short-circuits, so the memory product is only computed once n <= 2^22 and r <= 64.
    // That keeps it far below u64::MAX.
    if n > SCRYPT_MAX_N
        || r > SCRYPT_MAX_R
        || p > SCRYPT_MAX_P
        || 128 * n * u64::from(r) > SCRYPT_MAX_MEM_BYTES
    {
        return Err(BulkError::UnsupportedFormat(
            "slot scrypt parameters exceed sanity caps",
        ));
    }
    // n is a power of two no larger than SCRYPT_MAX_N, so log2(n) <= 22 fits in a u8.
    let log_n = n.trailing_zeros() as u8;
    scrypt::Params::new(log_n, r, p, 32)
        .map_err(|_| BulkError::UnsupportedFormat("invalid scrypt params"))
}

/// Decrypts the base64 `db` payload with `master_key`.
///
/// Returns the UTF-8 plaintext wrapped as `{"db":<plaintext>}`. Within this function the
/// plaintext is only held in buffers that are zeroed before being freed.
fn decrypt_db(
    master_key: &[u8],
    db_params: &AeadParams,
    db_b64: &str,
) -> Result<Zeroizing<String>, BulkError> {
    const PREFIX: &str = r#"{"db":"#;
    const SUFFIX: char = '}';

    let db_nonce = hex_decode(&db_params.nonce, "db nonce")?;
    let db_tag = hex_decode(&db_params.tag, "db tag")?;
    let db_ct = B64
        .decode(db_b64.as_bytes())
        .map_err(|_| BulkError::UnsupportedFormat("db is not valid base64"))?;
    let plaintext = gcm_decrypt(master_key, &db_nonce, &db_ct, &db_tag)
        .map_err(|()| BulkError::UnsupportedFormat("db did not decrypt with master key"))?;
    let inner = Zeroizing::new(String::from_utf8(plaintext).map_err(|e| {
        // Scrub the rejected plaintext bytes before they are dropped.
        let mut bytes = e.into_bytes();
        bytes.zeroize();
        BulkError::UnsupportedFormat("decrypted db is not UTF-8")
    })?);
    // Reserve the exact final length so the pushes below never reallocate. A reallocation
    // would free an unzeroed copy of the plaintext.
    let mut wrapped = Zeroizing::new(String::with_capacity(
        PREFIX.len() + inner.len() + SUFFIX.len_utf8(),
    ));
    wrapped.push_str(PREFIX);
    wrapped.push_str(&inner);
    wrapped.push(SUFFIX);
    Ok(wrapped)
}
/// AES-256-GCM authenticated decryption of ciphertext `ct` with a separately stored `tag`
/// and empty associated data.
///
/// Returns `Err(())` if `key` is not 32 bytes, `nonce` is not 12 bytes, `tag` is not 16 bytes,
/// or authentication fails. Callers attach their own descriptive error.
fn gcm_decrypt(key: &[u8], nonce: &[u8], ct: &[u8], tag: &[u8]) -> Result<Vec<u8>, ()> {
    // `nonce.into()` below converts to a fixed 12-byte array, so all sizes are checked first.
    let sizes_ok = key.len() == 32 && nonce.len() == 12 && tag.len() == 16;
    if !sizes_ok {
        return Err(());
    }
    let cipher = Aes256Gcm::new_from_slice(key).map_err(|_| ())?;
    // The `Aead` API expects the tag appended to the ciphertext.
    let sealed: Vec<u8> = ct.iter().chain(tag).copied().collect();
    let payload = Payload {
        msg: &sealed,
        aad: b"",
    };
    cipher.decrypt(nonce.into(), payload).map_err(|_| ())
}