use aes_gcm::{KeyInit, aead::Aead};
use anyhow::{Context, Result, anyhow};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::algorithm::{Algorithm, Method};
/// Database format version stamped into databases created through `Database::default()`.
// NOTE(review): `Aegis::restore_from_data` rejects plaintext databases whose version is
// greater than 2, so a plaintext vault carrying this value would not be read back — confirm
// whether that is intended.
const DB_VER: u32 = 3;
/// Top-level representation of an Aegis vault file.
///
/// The enum is `#[serde(untagged)]`, so the variant is selected by the shape of the JSON:
/// `Encrypted` stores `db` as a string, `Plaintext` stores it as a structured database.
/// Variants are tried in declaration order, so the order below must not be changed.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Aegis {
/// Vault whose database is stored as an encoded, encrypted string.
Encrypted(AegisEncrypted),
/// Vault whose database is stored in the clear.
Plaintext(AegisPlainText),
}
/// Unencrypted vault: the database is embedded as structured data.
#[derive(Debug, Serialize, Deserialize)]
pub struct AegisPlainText {
// Vault format version; `restore_from_data` only accepts the value 1.
version: u32,
// Key slots and encryption parameters; both are `None` in the default plaintext vault.
header: Header,
// The entries and their database format version.
db: Database,
}
impl Default for AegisPlainText {
fn default() -> Self {
Self { version: 1, header: Header { params: None, slots: None }, db: Default::default() }
}
}
/// Encrypted vault: the database is stored as a single encoded string.
#[derive(Debug, Serialize, Deserialize)]
pub struct AegisEncrypted {
// Vault format version; `restore_from_data` only accepts the value 1.
version: u32,
// Key slots and the nonce/tag needed to decrypt `db`.
header: Header,
// Base64 encoding of the encrypted database (the authentication tag lives in the header).
db: String,
}
impl Default for Aegis {
fn default() -> Self {
Self::Plaintext(AegisPlainText::default())
}
}
impl Aegis {
pub fn add_entry(&mut self, entry: Entry) {
if let Self::Plaintext(plain_text) = self {
plain_text.db.entries.push(entry);
} else {
panic!("Trying to add an OTP entry to an encrypted aegis database")
}
}
pub fn encrypt(&mut self, password: &str) -> Result<()> {
let mut rng = rand::rng();
let mut master_key = [0u8; 32];
rng.fill_bytes(&mut master_key);
let mut header = Header { params: Some(HeaderParam::default()), slots: Some(vec![HeaderSlot::default()]) };
let password_slot = &mut header.slots.as_mut().unwrap().get_mut(0).unwrap();
let mut derived_key: [u8; 32] = [0u8; 32];
let params = scrypt::Params::new(
(password_slot.n() as f64).log2() as u8,
password_slot.r(),
password_slot.p(),
scrypt::Params::RECOMMENDED_LEN,
)
.expect("Scrypt params creation");
scrypt::scrypt(password.as_bytes(), password_slot.salt(), ¶ms, &mut derived_key).map_err(|_| anyhow::anyhow!("Scrypt key derivation"))?;
let cipher = match aes_gcm::Aes256Gcm::new_from_slice(&derived_key) {
Ok(c) => c,
Err(_) => return Err(anyhow!("Could not create cipher from key")),
};
let mut ciphertext: Vec<u8> = cipher
.encrypt(aes_gcm::Nonce::from_slice(&password_slot.key_params.nonce), master_key.as_ref())
.map_err(|_| anyhow::anyhow!("Encrypter master key"))?;
password_slot.key_params.tag = ciphertext.split_off(32).try_into().unwrap();
password_slot.key = ciphertext.try_into().unwrap();
if let Self::Plaintext(plain_text) = self {
let db_json: Vec<u8> = serde_json::ser::to_string_pretty(&plain_text.db)?.as_bytes().to_vec();
let cipher = match aes_gcm::Aes256Gcm::new_from_slice(&master_key) {
Ok(c) => c,
Err(_) => return Err(anyhow!("Could not create cipher from master key")),
};
let mut ciphertext: Vec<u8> = cipher
.encrypt(aes_gcm::Nonce::from_slice(&header.params.as_ref().unwrap().nonce), db_json.as_ref())
.map_err(|_| anyhow::anyhow!("Encrypting aegis database"))?;
header.params.as_mut().unwrap().tag = ciphertext.split_off(ciphertext.len() - 16).try_into().unwrap();
let db_encrypted = ciphertext;
*self = Self::Encrypted(AegisEncrypted { version: plain_text.version, header, db: data_encoding::BASE64.encode(&db_encrypted) });
} else {
panic!("Encrypt can only be called on a plaintext object.")
}
Ok(())
}
pub fn restore_from_data(from: &[u8], key: Option<&str>) -> Result<Vec<Entry>> {
let aegis_root: Aegis = serde_json::de::from_slice(from)?;
let mut entries = Vec::new();
match aegis_root {
Aegis::Plaintext(plain_text) => {
println!("Found unencrypted aegis vault with version {} and database version {}.", plain_text.version, plain_text.db.version);
if plain_text.version != 1 {
anyhow::bail!("Aegis vault version expected to be 1. Found {} instead.", plain_text.version);
} else if plain_text.db.version > 2 {
anyhow::bail!("Aegis database version expected to be 1 or 2. Found {} instead.", plain_text.db.version);
} else {
for mut entry in plain_text.db.entries {
entry.fix_empty_issuer()?;
entries.push(entry);
}
Ok(entries)
}
}
Aegis::Encrypted(encrypted) => {
println!("Found encrypted aegis vault with version {}.", encrypted.version);
if encrypted.version != 1 {
anyhow::bail!("Aegis vault version expected to be 1. Found {} instead.", encrypted.version);
} else if key.is_none() {
anyhow::bail!("Found encrypted aegis database but no password given.");
}
let mut ciphertext = data_encoding::BASE64.decode(encrypted.db.as_bytes()).context("Cannot decode (base64) encoded database")?;
ciphertext.append(&mut encrypted.header.params.as_ref().unwrap().tag.into());
let master_keys: Vec<Vec<u8>> = encrypted
.header
.slots
.as_ref()
.unwrap()
.iter()
.filter(|slot| slot.type_ == 1) .map(|slot| -> Result<Vec<u8>> {
println!("Found possible master key with UUID {}.", slot.uuid);
let params = scrypt::Params::new(
(slot.n() as f64).log2() as u8, slot.r(), slot.p(), scrypt::Params::RECOMMENDED_LEN,
)
.map_err(|_| anyhow::anyhow!("Invalid scrypt parameters"))?;
let mut temp_key: [u8; 32] = [0u8; 32];
scrypt::scrypt(key.unwrap().as_bytes(), slot.salt(), ¶ms, &mut temp_key)
.map_err(|_| anyhow::anyhow!("Scrypt key derivation failed"))?;
let cipher = match aes_gcm::Aes256Gcm::new_from_slice(&temp_key) {
Ok(c) => c,
Err(_) => return Err(anyhow!("Could not create cipher from key")),
};
let mut ciphertext: Vec<u8> = slot.key.to_vec();
ciphertext.append(&mut slot.key_params.tag.to_vec());
cipher
.decrypt(aes_gcm::Nonce::from_slice(&slot.key_params.nonce), ciphertext.as_ref())
.map_err(|_| anyhow::anyhow!("Cannot decrypt master key"))
})
.filter_map(|x| match x {
Ok(x) => Some(x),
Err(e) => {
println!("Decrypting master key failed: {:?}", e);
None
}
})
.collect();
println!("Found {} valid password slots / master keys.", master_keys.len());
let master_key = match master_keys.first() {
Some(x) => {
println!("Using only the first valid key slot / master key.");
x
}
None => anyhow::bail!("Did not find at least one slot with a valid key. Wrong password?"),
};
let cipher = match aes_gcm::Aes256Gcm::new_from_slice(master_key) {
Ok(c) => c,
Err(_) => return Err(anyhow!("Could not create cipher from key")),
};
let plaintext = cipher
.decrypt(aes_gcm::Nonce::from_slice(&encrypted.header.params.as_ref().unwrap().nonce), ciphertext.as_ref())
.map_err(|_| anyhow::anyhow!("Cannot decrypt database"))?;
let db: Database = serde_json::de::from_slice(&plaintext).context("Deserialize decrypted database failed")?;
println!("Found aegis database with version {}.", db.version);
for mut entry in db.entries {
entry.fix_empty_issuer()?;
entries.push(entry);
}
Ok(entries)
}
}
}
pub fn save(&mut self, destination: &mut dyn std::io::Write, password: &str) -> Result<()> {
self.encrypt(password)?;
let raw_encrypted_vault = serde_json::ser::to_string_pretty(&self)?;
destination.write_all(raw_encrypted_vault.as_bytes())?;
destination.flush()?;
Ok(())
}
}
/// Vault header holding the key slots and the parameters used to encrypt the database.
///
/// Both fields are optional and default to `None` when missing from the input; the default
/// plaintext vault sets both to `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Header {
/// Key slots, each holding an encrypted copy of the master key.
#[serde(default)]
pub slots: Option<Vec<HeaderSlot>>,
/// Nonce and authentication tag of the encrypted database.
#[serde(default)]
pub params: Option<HeaderParam>,
}
/// A key slot: the master key encrypted with a key derived via scrypt.
#[derive(Debug, Serialize, Deserialize)]
pub struct HeaderSlot {
/// Slot kind, serialized as `type`. This module creates and reads slots of type `1`,
/// which it treats as password slots.
#[serde(rename = "type")]
pub type_: u32,
/// Identifier of the slot; a random UUID (v4) string for slots created by this module.
pub uuid: String,
/// Encrypted master key (32 bytes, hex-encoded in the file).
#[serde(with = "hex::serde")]
pub key: [u8; 32],
/// Nonce and authentication tag used to encrypt `key`.
pub key_params: HeaderParam,
// scrypt cost parameter N; see `HeaderSlot::n` for the fallback used when absent.
n: Option<u32>,
// scrypt block size r; see `HeaderSlot::r` for the fallback used when absent.
r: Option<u32>,
// scrypt parallelization p; see `HeaderSlot::p` for the fallback used when absent.
p: Option<u32>,
// scrypt salt (32 bytes, hex-encoded in the file); all zeroes when missing from the input.
#[serde(default, with = "hex::serde")]
salt: [u8; 32],
}
impl HeaderSlot {
    /// scrypt cost parameter N, falling back to 2^15 when the slot does not specify one.
    pub fn n(&self) -> u32 {
        match self.n {
            Some(n) => n,
            None => 1 << 15,
        }
    }

    /// scrypt block size r, falling back to 8 when the slot does not specify one.
    pub fn r(&self) -> u32 {
        match self.r {
            Some(r) => r,
            None => 8,
        }
    }

    /// scrypt parallelization p, falling back to 1 when the slot does not specify one.
    pub fn p(&self) -> u32 {
        match self.p {
            Some(p) => p,
            None => 1,
        }
    }

    /// The 32-byte scrypt salt of this slot.
    pub fn salt(&self) -> &[u8; 32] {
        let Self { salt, .. } = self;
        salt
    }
}
impl Default for HeaderSlot {
    /// Creates a fresh slot of type 1 with a random salt, a random UUID, fresh key
    /// parameters, an all-zero key placeholder and the scrypt parameters N = 2^15, r = 8,
    /// p = 1.
    fn default() -> Self {
        let mut salt = [0u8; 32];
        rand::rng().fill_bytes(&mut salt);
        let uuid = uuid::Uuid::new_v4().to_string();

        Self {
            type_: 1,
            uuid,
            key: [0u8; 32],
            key_params: HeaderParam::default(),
            n: Some(1 << 15),
            r: Some(8),
            p: Some(1),
            salt,
        }
    }
}
/// AES-GCM parameters stored next to a ciphertext; both fields are hex-encoded in the file.
#[derive(Debug, Serialize, Deserialize)]
pub struct HeaderParam {
/// 12-byte nonce used for the encryption.
#[serde(with = "hex::serde")]
pub nonce: [u8; 12],
/// 16-byte authentication tag, stored separately from the ciphertext.
#[serde(with = "hex::serde")]
pub tag: [u8; 16],
}
impl Default for HeaderParam {
    /// Creates parameters with a random nonce and an all-zero tag placeholder.
    fn default() -> Self {
        let mut nonce = [0u8; 12];
        rand::rng().fill_bytes(&mut nonce);
        Self {
            nonce,
            tag: [0u8; 16],
        }
    }
}
/// The vault database: a format version and the list of OTP entries.
#[derive(Debug, Serialize, Deserialize)]
pub struct Database {
/// Database format version.
pub version: u32,
/// The stored OTP entries.
pub entries: Vec<Entry>,
/// Optional `groups` value; `None` for databases created by this module.
// NOTE(review): the element type `bool` looks like a placeholder — confirm against the
// vault format.
pub groups: Option<Vec<bool>>, }
impl Default for Database {
fn default() -> Self {
Self { version: DB_VER, entries: std::vec::Vec::new(), groups: None }
}
}
/// A single OTP entry of the database.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Entry {
/// Identifier of the entry.
pub uuid: String,
/// OTP method, serialized as `type`.
#[serde(rename = "type")]
pub method: Method,
/// Account label, serialized as `name`.
#[serde(rename = "name")]
pub label: String,
/// Issuer of the account; when absent, `fix_empty_issuer` tries to derive it from `label`.
pub issuer: Option<String>,
/// Optional value serialized as `groups`.
#[serde(rename = "groups")]
pub tags: Option<String>,
/// Optional value serialized as `icon`.
#[serde(rename = "icon")]
pub thumbnail: Option<String>,
/// Secret and OTP parameters of the entry.
pub info: Detail,
}
impl Entry {
fn fix_empty_issuer(&mut self) -> Result<()> {
if self.issuer.is_none() {
let mut vals: Vec<&str> = self.label.split('@').collect();
if vals.len() > 1 {
self.issuer = vals.pop().map(ToOwned::to_owned);
self.label = vals.join("@");
} else {
anyhow::bail!("Entry {} has an empty issuer", self.label);
}
}
Ok(())
}
pub fn label(&self) -> String {
self.label.clone()
}
pub fn issuer(&self) -> String {
self.issuer.as_ref().map(ToOwned::to_owned).unwrap_or_default()
}
pub fn secret(&self) -> String {
self.info.secret.clone()
}
pub fn period(&self) -> Option<u32> {
self.info.period
}
pub fn method(&self) -> Method {
self.method
}
pub fn algorithm(&self) -> Algorithm {
self.info.algorithm
}
pub fn digits(&self) -> Option<u32> {
Some(self.info.digits)
}
pub fn counter(&self) -> Option<u32> {
self.info.counter
}
}
/// Secret and OTP parameters of an entry.
///
/// Derives `ZeroizeOnDrop`; every field except `secret` is marked `#[zeroize(skip)]`, so
/// only the secret is wiped when the value is dropped.
#[derive(Debug, Default, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
pub struct Detail {
/// The OTP secret.
pub secret: String,
/// Hash algorithm, serialized as `algo`.
#[serde(rename = "algo")]
#[zeroize(skip)]
pub algorithm: Algorithm,
/// Number of digits of the generated codes.
#[zeroize(skip)]
pub digits: u32,
/// Optional period of the entry.
#[zeroize(skip)]
pub period: Option<u32>,
/// Optional counter of the entry.
#[zeroize(skip)]
pub counter: Option<u32>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Entries without an issuer take it from the text after the last `@` of their name.
    #[test]
    fn issuer_from_name() {
        let raw = std::fs::read_to_string("./test_databases/aegis_issuer_from_name.json").unwrap();
        let restored = Aegis::restore_from_data(raw.as_bytes(), None).unwrap();

        let simple = &restored[0];
        assert_eq!(simple.issuer(), "issuer");
        assert_eq!(simple.label(), "missing-issuer");

        let with_domain = &restored[1];
        assert_eq!(with_domain.issuer(), "issuer");
        assert_eq!(with_domain.label(), "missing-issuer@domain.com");
    }

    /// A plaintext vault is parsed into the expected TOTP, HOTP and Steam entries.
    #[test]
    fn parse_plain() {
        let raw = std::fs::read_to_string("./test_databases/aegis_plain.json").unwrap();
        let restored = Aegis::restore_from_data(raw.as_bytes(), None).unwrap();

        let bob = &restored[0];
        assert_eq!(bob.label(), "Bob");
        assert_eq!(bob.issuer(), "Google");
        assert_eq!(bob.secret(), "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567");
        assert_eq!(bob.period(), Some(30));
        assert_eq!(bob.algorithm(), Algorithm::SHA1);
        assert_eq!(bob.digits(), Some(6));
        assert_eq!(bob.counter(), None);
        assert_eq!(bob.method(), Method::TOTP);

        let benjamin = &restored[1];
        assert_eq!(benjamin.label(), "Benjamin");
        assert_eq!(benjamin.issuer(), "Air Canada");
        assert_eq!(benjamin.secret(), "KUVJJOM753IHTNDSZVCNKL7GII");
        assert_eq!(benjamin.period(), None);
        assert_eq!(benjamin.algorithm(), Algorithm::SHA256);
        assert_eq!(benjamin.digits(), Some(7));
        assert_eq!(benjamin.counter(), Some(50));
        assert_eq!(benjamin.method(), Method::HOTP);

        let sophia = &restored[2];
        assert_eq!(sophia.label(), "Sophia");
        assert_eq!(sophia.issuer(), "Boeing");
        assert_eq!(sophia.secret(), "JRZCL47CMXVOQMNPZR2F7J4RGI");
        assert_eq!(sophia.period(), Some(30));
        assert_eq!(sophia.algorithm(), Algorithm::SHA1);
        assert_eq!(sophia.digits(), Some(5));
        assert_eq!(sophia.counter(), None);
        assert_eq!(sophia.method(), Method::Steam);
    }

    /// An encrypted vault is decrypted with its password and parsed into the expected entries.
    #[test]
    fn parse_encrypted() {
        let raw = std::fs::read_to_string("./test_databases/aegis_encrypted.json").unwrap();
        let restored = Aegis::restore_from_data(raw.as_bytes(), Some("test")).unwrap();

        let mason = &restored[0];
        assert_eq!(mason.label(), "Mason");
        assert_eq!(mason.issuer(), "Deno");
        assert_eq!(mason.secret(), "4SJHB4GSD43FZBAI7C2HLRJGPQ");
        assert_eq!(mason.period(), Some(30));
        assert_eq!(mason.algorithm(), Algorithm::SHA1);
        assert_eq!(mason.digits(), Some(6));
        assert_eq!(mason.counter(), None);
        assert_eq!(mason.method(), Method::TOTP);

        let james = &restored[3];
        assert_eq!(james.label(), "James");
        assert_eq!(james.issuer(), "Issuu");
        assert_eq!(james.secret(), "YOOMIXWS5GN6RTBPUFFWKTW5M4");
        assert_eq!(james.period(), None);
        assert_eq!(james.algorithm(), Algorithm::SHA1);
        assert_eq!(james.digits(), Some(6));
        assert_eq!(james.counter(), Some(1));
        assert_eq!(james.method(), Method::HOTP);

        let sophia = &restored[6];
        assert_eq!(sophia.label(), "Sophia");
        assert_eq!(sophia.issuer(), "Boeing");
        assert_eq!(sophia.secret(), "JRZCL47CMXVOQMNPZR2F7J4RGI");
        assert_eq!(sophia.period(), Some(30));
        assert_eq!(sophia.algorithm(), Algorithm::SHA1);
        assert_eq!(sophia.digits(), Some(5));
        assert_eq!(sophia.counter(), None);
        assert_eq!(sophia.method(), Method::Steam);
    }

    /// A vault built in memory survives an encrypt / serialize / restore round trip.
    #[test]
    fn encrypt() {
        let mut vault = Aegis::default();
        let password = "my-super-secure-password";

        let mut totp = Entry::default();
        totp.method = Method::TOTP;
        totp.label = "Mason".to_string();
        totp.issuer = Some("Deno".to_string());
        totp.info.secret = "4SJHB4GSD43FZBAI7C2HLRJGPQ".to_string();
        totp.info.period = Some(30);
        totp.info.digits = 6;
        totp.info.counter = None;
        vault.add_entry(totp);

        let mut hotp = Entry::default();
        hotp.method = Method::HOTP;
        hotp.label = "James".to_string();
        hotp.issuer = Some("Issuu".to_string());
        hotp.info.secret = "YOOMIXWS5GN6RTBPUFFWKTW5M4".to_string();
        hotp.info.algorithm = Algorithm::SHA1;
        hotp.info.period = None;
        hotp.info.digits = 6;
        hotp.info.counter = Some(1);
        vault.add_entry(hotp);

        vault.encrypt(password).unwrap();
        let serialized = serde_json::ser::to_string_pretty(&vault).unwrap();
        let restored = Aegis::restore_from_data(serialized.as_bytes(), Some(password)).unwrap();

        let mason = &restored[0];
        assert_eq!(mason.label(), "Mason");
        assert_eq!(mason.issuer(), "Deno");
        assert_eq!(mason.secret(), "4SJHB4GSD43FZBAI7C2HLRJGPQ");
        assert_eq!(mason.period(), Some(30));
        assert_eq!(mason.algorithm(), Algorithm::SHA1);
        assert_eq!(mason.digits(), Some(6));
        assert_eq!(mason.counter(), None);
        assert_eq!(mason.method(), Method::TOTP);

        let james = &restored[1];
        assert_eq!(james.label(), "James");
        assert_eq!(james.issuer(), "Issuu");
        assert_eq!(james.secret(), "YOOMIXWS5GN6RTBPUFFWKTW5M4");
        assert_eq!(james.period(), None);
        assert_eq!(james.algorithm(), Algorithm::SHA1);
        assert_eq!(james.digits(), Some(6));
        assert_eq!(james.counter(), Some(1));
        assert_eq!(james.method(), Method::HOTP);
    }
}