use std::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
use clear_on_drop::clear::Clear;
use nimiq_database_value_derive::DbSerializable;
use nimiq_hash::argon2kdf::{compute_argon2_kdf, Argon2Error, Argon2Variant};
use nimiq_serde::{Deserialize, Serialize};
use rand::{rngs::OsRng, RngCore as _, TryRngCore as _};
pub fn otp(
secret: &[u8],
password: &[u8],
iterations: u32,
salt: &[u8],
algorithm: Algorithm,
) -> Result<Vec<u8>, Argon2Error> {
let mut key = compute_argon2_kdf(password, salt, iterations, secret.len(), algorithm.into())?;
assert_eq!(key.len(), secret.len());
for (key_byte, secret_byte) in key.iter_mut().zip(secret.iter()) {
*key_byte ^= secret_byte;
}
Ok(key)
}
/// Self-check run on a freshly decrypted secret.
///
/// `Locked::unlock` calls this on the deserialized value and treats `false`
/// as a failed unlock. This presumably lets callers reject a wrong password
/// whose garbage output still happened to deserialize.
pub trait Verify {
    /// Returns `true` if the value is considered valid.
    fn verify(&self) -> bool;
}
/// Owning guard that wipes its contents with [`Clear::clear`] when dropped.
///
/// `place` is `Some` for the whole life of the guard. It only becomes `None`
/// inside `into_uncleared_place`, which consumes the guard.
struct ClearOnDrop<T: Clear> {
    place: Option<T>,
}
impl<T: Clear> ClearOnDrop<T> {
    /// Takes ownership of `place`; it will be wiped when the guard drops.
    #[inline]
    fn new(place: T) -> Self {
        Self { place: Some(place) }
    }

    /// Moves the value out of the guard *without* clearing it.
    ///
    /// The guard is left holding `None`, so its `Drop` impl has nothing left
    /// to wipe. The caller becomes responsible for the returned value.
    #[inline]
    fn into_uncleared_place(mut guard: Self) -> T {
        let value = guard.place.take();
        value.unwrap()
    }
}
impl<T: Clear> Drop for ClearOnDrop<T> {
    /// Wipes the guarded value if it has not been moved out.
    #[inline]
    fn drop(&mut self) {
        if let Some(data) = self.place.as_mut() {
            data.clear();
        }
    }
}
impl<T: Clear> Deref for ClearOnDrop<T> {
    type Target = T;

    /// Borrows the guarded value.
    ///
    /// The `unwrap` cannot fail while the guard is alive. Only the consuming
    /// `into_uncleared_place` ever empties `place`.
    #[inline]
    fn deref(&self) -> &T {
        let inner = &self.place;
        inner.as_ref().unwrap()
    }
}
impl<T: Clear> DerefMut for ClearOnDrop<T> {
    /// Mutably borrows the guarded value.
    ///
    /// The `unwrap` cannot fail while the guard is alive. Only the consuming
    /// `into_uncleared_place` ever empties `place`.
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        let inner = &mut self.place;
        inner.as_mut().unwrap()
    }
}
impl<T: Clear> AsRef<T> for ClearOnDrop<T> {
#[inline]
fn as_ref(&self) -> &T {
self.place.as_ref().unwrap()
}
}
/// A secret held in plaintext together with its encrypted form.
///
/// The plaintext is wiped when this value is dropped. The encrypted form is
/// what `Unlocked::lock` hands back.
pub struct Unlocked<T: Clear + Deserialize + Serialize> {
    /// Plaintext secret, cleared on drop.
    data: ClearOnDrop<T>,
    /// Encrypted copy of `data`.
    lock: Locked<T>,
}
impl<T: Clear + Deserialize + Serialize> Unlocked<T> {
    /// Encrypts `secret` under `password` with a fresh random salt and keeps
    /// the plaintext available.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation. In that case `secret`
    /// is wiped before it is dropped.
    pub fn new(
        secret: T,
        password: &[u8],
        iterations: u32,
        salt_length: usize,
        algorithm: Algorithm,
    ) -> Result<Self, Argon2Error> {
        // Put the secret under the guard *before* any fallible work. Otherwise
        // an early `?` return would drop the plaintext without clearing it.
        let data = ClearOnDrop::new(secret);
        let lock = Locked::create(&*data, password, iterations, salt_length, algorithm)?;
        Ok(Unlocked { data, lock })
    }

    /// Like [`Unlocked::new`], using the `OtpLock` default iteration count,
    /// salt length and algorithm.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    pub fn with_defaults(secret: T, password: &[u8]) -> Result<Self, Argon2Error> {
        Self::new(
            secret,
            password,
            OtpLock::<T>::DEFAULT_ITERATIONS,
            OtpLock::<T>::DEFAULT_SALT_LENGTH,
            Algorithm::default(),
        )
    }

    /// Discards the plaintext (it is cleared on drop) and returns the encrypted form.
    #[inline]
    pub fn lock(lock: Self) -> Locked<T> {
        lock.lock
    }

    /// Wraps this value in [`OtpLock::Unlocked`].
    #[inline]
    pub fn into_otp_lock(lock: Self) -> OtpLock<T> {
        OtpLock::Unlocked(lock)
    }

    /// Moves the plaintext out *without* clearing it. The caller is then
    /// responsible for wiping it.
    #[inline]
    pub fn into_unlocked_data(lock: Self) -> T {
        ClearOnDrop::into_uncleared_place(lock.data)
    }

    /// Borrows the plaintext secret.
    #[inline]
    pub fn unlocked_data(lock: &Self) -> &T {
        &lock.data
    }
}
impl<T: Clear + Deserialize + Serialize> Deref for Unlocked<T> {
    type Target = T;

    /// Exposes the plaintext secret.
    #[inline]
    fn deref(&self) -> &T {
        self.data.deref()
    }
}
/// Argon2 variant used to derive the one-time pad.
///
/// NOTE(review): the explicit discriminants appear to follow the Argon2
/// reference numbering, in which 1 is Argon2i, a variant not offered here.
/// Confirm before relying on the numeric values.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub enum Algorithm {
    /// Argon2d. Also the value assumed for serialized locks lacking an algorithm field.
    Argon2d = 0,
    /// Argon2id. This is the `Default` variant used by the `*_with_defaults` constructors.
    #[default]
    Argon2id = 2,
}
impl Algorithm {
pub fn backwards_compatible_default() -> Algorithm {
Self::Argon2d
}
}
impl From<Algorithm> for Argon2Variant {
    /// Maps the serializable [`Algorithm`] onto the KDF's variant type.
    fn from(algorithm: Algorithm) -> Self {
        match algorithm {
            Algorithm::Argon2id => Self::Argon2id,
            Algorithm::Argon2d => Self::Argon2d,
        }
    }
}
/// A secret encrypted by XOR-ing its serialized bytes with an Argon2-derived
/// pad (see `otp`).
///
/// Everything needed to re-derive the pad except the password is stored here.
#[derive(Serialize, Deserialize, DbSerializable)]
pub struct Locked<T: Clear + Deserialize + Serialize> {
    /// Ciphertext: the serialized secret XOR-ed with the derived pad.
    lock: Vec<u8>,
    /// Salt fed to the KDF. `create` fills it from the OS RNG.
    salt: Vec<u8>,
    /// Argon2 iteration count passed to the KDF.
    iterations: u32,
    /// KDF variant. Falls back to `Algorithm::backwards_compatible_default`
    /// (Argon2d) when missing from the serialized data, presumably for values
    /// written before this field existed.
    /// NOTE(review): whether a missing field is actually tolerated depends on
    /// the nimiq_serde wire format. Confirm.
    #[serde(default = "Algorithm::backwards_compatible_default")]
    algorithm: Algorithm,
    /// Ties the ciphertext to the type it decrypts to without storing a `T`.
    phantom: PhantomData<T>,
}
impl<T: Clear + Deserialize + Serialize> Locked<T> {
    /// Encrypts `secret` under `password` with a fresh random salt, consuming
    /// the plaintext.
    ///
    /// The plaintext is cleared whether or not encryption succeeds.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    pub fn new(
        mut secret: T,
        password: &[u8],
        iterations: u32,
        salt_length: usize,
        algorithm: Algorithm,
    ) -> Result<Self, Argon2Error> {
        let result = Locked::create(&secret, password, iterations, salt_length, algorithm);
        // Wipe before propagating: an early `?` would otherwise drop the
        // plaintext without clearing it.
        secret.clear();
        result
    }

    /// Like [`Locked::new`], using the `OtpLock` default iteration count,
    /// salt length and algorithm.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    pub fn with_defaults(secret: T, password: &[u8]) -> Result<Self, Argon2Error> {
        Self::new(
            secret,
            password,
            OtpLock::<T>::DEFAULT_ITERATIONS,
            OtpLock::<T>::DEFAULT_SALT_LENGTH,
            Algorithm::default(),
        )
    }

    /// Decrypts with `password` and deserializes the result, *without* any
    /// `Verify` check.
    ///
    /// Returns `Err(self)` if key derivation fails or the decrypted bytes do
    /// not deserialize as `T`. The intermediate decrypted buffer is zeroed in
    /// both cases.
    pub fn unlock_unchecked(self, password: &[u8]) -> Result<Unlocked<T>, Locked<T>> {
        let derived = otp(
            &self.lock,
            password,
            self.iterations,
            &self.salt,
            self.algorithm,
        );
        let mut key = match derived {
            Ok(key) => key,
            Err(_) => return Err(self),
        };

        let result = T::deserialize_from_vec(&key).ok();
        // `key` now holds the serialized plaintext; zero it before it is freed.
        for byte in key.iter_mut() {
            byte.clear();
        }

        match result {
            Some(data) => Ok(Unlocked {
                data: ClearOnDrop::new(data),
                lock: self,
            }),
            None => Err(self),
        }
    }

    /// Encrypts `secret` under a pad derived from `password` and `salt`.
    ///
    /// The temporary serialized plaintext is zeroed on every path, including
    /// when key derivation fails.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    fn lock(
        secret: &T,
        password: &[u8],
        iterations: u32,
        salt: Vec<u8>,
        algorithm: Algorithm,
    ) -> Result<Self, Argon2Error> {
        let mut data = secret.serialize_to_vec();
        let result = otp(&data, password, iterations, &salt, algorithm);
        // Zero the plaintext buffer before looking at the result, so an error
        // return cannot leave it behind in freed memory.
        for byte in data.iter_mut() {
            byte.clear();
        }
        let lock = result?;
        Ok(Locked {
            lock,
            salt,
            iterations,
            algorithm,
            phantom: PhantomData,
        })
    }

    /// Generates a `salt_length`-byte salt from the OS RNG and locks `secret` with it.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    fn create(
        secret: &T,
        password: &[u8],
        iterations: u32,
        salt_length: usize,
        algorithm: Algorithm,
    ) -> Result<Self, Argon2Error> {
        let mut salt = vec![0; salt_length];
        OsRng.unwrap_err().fill_bytes(salt.as_mut_slice());
        Self::lock(secret, password, iterations, salt, algorithm)
    }

    /// Wraps this value in [`OtpLock::Locked`].
    pub fn into_otp_lock(self) -> OtpLock<T> {
        OtpLock::Locked(self)
    }
}
impl<T: Clear + Deserialize + Serialize + Verify> Locked<T> {
    /// Decrypts with `password` and accepts the result only if
    /// [`Verify::verify`] returns `true`.
    ///
    /// On any failure the `Locked` value is handed back as `Err`. A decrypted
    /// value that fails verification is cleared when it is dropped.
    pub fn unlock(self, password: &[u8]) -> Result<Unlocked<T>, Locked<T>> {
        let candidate = self.unlock_unchecked(password)?;
        if !candidate.verify() {
            return Err(candidate.lock);
        }
        Ok(candidate)
    }
}
/// A secret that is either currently decrypted ([`Unlocked`]) or held only in
/// encrypted form ([`Locked`]).
pub enum OtpLock<T: Clear + Deserialize + Serialize> {
    /// Plaintext is available alongside its encrypted copy.
    Unlocked(Unlocked<T>),
    /// Only the encrypted form is held.
    Locked(Locked<T>),
}
impl<T: Clear + Deserialize + Serialize> OtpLock<T> {
    /// Salt length in bytes used by the `with_defaults` constructors.
    pub const DEFAULT_SALT_LENGTH: usize = 32;
    /// Argon2 iteration count used by the `with_defaults` constructors.
    pub const DEFAULT_ITERATIONS: u32 = 3;

    /// Builds an [`OtpLock::Unlocked`] from a plaintext secret.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    pub fn new_unlocked(
        secret: T,
        password: &[u8],
        iterations: u32,
        salt_length: usize,
        algorithm: Algorithm,
    ) -> Result<Self, Argon2Error> {
        Unlocked::new(secret, password, iterations, salt_length, algorithm).map(OtpLock::Unlocked)
    }

    /// [`OtpLock::new_unlocked`] with the default parameters.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    pub fn unlocked_with_defaults(secret: T, password: &[u8]) -> Result<Self, Argon2Error> {
        let (iterations, salt_length) = (Self::DEFAULT_ITERATIONS, Self::DEFAULT_SALT_LENGTH);
        Self::new_unlocked(secret, password, iterations, salt_length, Algorithm::default())
    }

    /// Builds an [`OtpLock::Locked`] from a plaintext secret, consuming it.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    pub fn new_locked(
        secret: T,
        password: &[u8],
        iterations: u32,
        salt_length: usize,
        algorithm: Algorithm,
    ) -> Result<Self, Argon2Error> {
        Locked::new(secret, password, iterations, salt_length, algorithm).map(OtpLock::Locked)
    }

    /// [`OtpLock::new_locked`] with the default parameters.
    ///
    /// # Errors
    /// Returns the [`Argon2Error`] from key derivation.
    pub fn locked_with_defaults(secret: T, password: &[u8]) -> Result<Self, Argon2Error> {
        let (iterations, salt_length) = (Self::DEFAULT_ITERATIONS, Self::DEFAULT_SALT_LENGTH);
        Self::new_locked(secret, password, iterations, salt_length, Algorithm::default())
    }

    /// `true` if only the encrypted form is held.
    #[inline]
    pub fn is_locked(&self) -> bool {
        match self {
            OtpLock::Locked(_) => true,
            OtpLock::Unlocked(_) => false,
        }
    }

    /// `true` if the plaintext is available.
    #[inline]
    pub fn is_unlocked(&self) -> bool {
        matches!(self, OtpLock::Unlocked(_))
    }

    /// Returns an [`OtpLock::Locked`], discarding (and clearing) any plaintext.
    #[inline]
    #[must_use]
    pub fn lock(self) -> Self {
        OtpLock::Locked(self.locked())
    }

    /// Extracts the encrypted form, discarding (and clearing) any plaintext.
    #[inline]
    pub fn locked(self) -> Locked<T> {
        match self {
            OtpLock::Locked(locked) => locked,
            OtpLock::Unlocked(unlocked) => Unlocked::lock(unlocked),
        }
    }

    /// Returns the unlocked value, or gives `self` back if it is locked.
    #[inline]
    pub fn unlocked(self) -> Result<Unlocked<T>, Self> {
        match self {
            OtpLock::Unlocked(unlocked) => Ok(unlocked),
            locked @ OtpLock::Locked(_) => Err(locked),
        }
    }

    /// Borrows the unlocked value, if any.
    #[inline]
    pub fn unlocked_ref(&self) -> Option<&Unlocked<T>> {
        if let OtpLock::Unlocked(unlocked) = self {
            Some(unlocked)
        } else {
            None
        }
    }
}