use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use argon2::{self, Argon2};
use rand::RngCore;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
ffi::{OsStr, OsString},
fmt,
fs::{self, File},
io::{self, Read, Write},
ops::{Deref, DerefMut},
path::{Path, PathBuf},
sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
};
use tracing::{error, info};
const SALT_LEN: usize = 16;
const NONCE_LEN: usize = 12;
/// On-disk envelope of an encrypted database file; the whole struct is stored as JSON.
#[derive(Serialize, Deserialize)]
struct EncryptedData {
    /// Random salt passed to `derive_key` to turn the password into the AES key.
    salt: Vec<u8>,
    /// AES-GCM nonce used for `ciphertext`; the default `save_encrypted` draws a new one per save.
    nonce: Vec<u8>,
    /// AES-256-GCM encryption of the JSON-serialized store.
    ciphertext: Vec<u8>,
}
/// A serializable value that can be persisted as an AES-256-GCM encrypted JSON file.
///
/// Every method has a default implementation built on `serde_json` and `aes_gcm`, so
/// implementors normally only need the `Default + Serialize` supertraits (plus
/// `DeserializeOwned` to be loadable).
pub trait EncryptedDataStore: Default + Serialize {
    /// Opens the encrypted database at `db`: loads it if the file exists, otherwise creates a
    /// new one holding `Self::default()`.
    ///
    /// # Errors
    /// Propagates any error from `EncryptedAtomicDatabase::load` or
    /// `EncryptedAtomicDatabase::create`.
    fn open<P>(db: P, password: &str) -> io::Result<EncryptedAtomicDatabase<Self>>
    where
        P: AsRef<Path>,
        Self: DeserializeOwned,
    {
        let db_path = db.as_ref();
        if db_path.exists() {
            EncryptedAtomicDatabase::load(db_path, password)
        } else {
            EncryptedAtomicDatabase::create(db_path, password)
        }
    }

    /// Reads an encrypted JSON envelope from `file`, decrypts it with `key` and deserializes
    /// the plaintext as `Self`.
    ///
    /// The salt stored in the envelope is not used here; `key` must already have been derived
    /// from it.
    ///
    /// # Errors
    /// Returns `InvalidData` if the envelope cannot be parsed, its nonce is not `NONCE_LEN`
    /// bytes long, decryption fails (wrong key or modified ciphertext), or the plaintext is not
    /// valid JSON for `Self`.
    fn load_encrypted(file: impl Read, key: &Key<Aes256Gcm>) -> io::Result<Self>
    where
        Self: Sized,
        Self: DeserializeOwned,
    {
        let encrypted: EncryptedData = serde_json::from_reader(file).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Failed to deserialize encrypted data: {}", e),
            )
        })?;
        // The nonce comes from the file and is untrusted. `Nonce::from_slice` panics on a
        // length mismatch, which would turn a corrupt file into a crash instead of an error.
        if encrypted.nonce.len() != NONCE_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Invalid nonce length: expected {} bytes, found {}",
                    NONCE_LEN,
                    encrypted.nonce.len()
                ),
            ));
        }
        let cipher = Aes256Gcm::new(key);
        let nonce = Nonce::from_slice(&encrypted.nonce);
        let decrypted_bytes = cipher
            .decrypt(nonce, encrypted.ciphertext.as_ref())
            .map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Decryption failed: {}", e),
                )
            })?;
        let data = serde_json::from_slice(&decrypted_bytes).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Failed to deserialize decrypted data: {}", e),
            )
        })?;
        Ok(data)
    }

    /// Serializes `self` to JSON, encrypts it with `key` under a freshly generated random
    /// nonce, and writes the envelope (salt, nonce, ciphertext) as JSON to `file`.
    ///
    /// `salt` is stored verbatim so the key can be re-derived from the password on load.
    /// `serde_json::to_writer` may issue many small writes, so pass a buffered writer.
    ///
    /// # Errors
    /// Returns `InvalidData` if `self` cannot be serialized, and `Other` if encryption or
    /// writing the envelope fails.
    fn save_encrypted(
        &self,
        file: impl Write,
        key: &Key<Aes256Gcm>,
        salt: &[u8],
    ) -> io::Result<()>
    where
        Self: Serialize,
    {
        // AES-GCM must never reuse a nonce with the same key, so draw a new one on every save.
        let mut nonce_bytes = vec![0u8; NONCE_LEN];
        OsRng.fill_bytes(&mut nonce_bytes);
        let cipher = Aes256Gcm::new(key);
        let nonce = Nonce::from_slice(&nonce_bytes);
        let plaintext = serde_json::to_vec(self).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Serialization failed: {}", e),
            )
        })?;
        let ciphertext = cipher.encrypt(nonce, plaintext.as_ref()).map_err(|e| {
            io::Error::new(io::ErrorKind::Other, format!("Encryption failed: {}", e))
        })?;
        let encrypted = EncryptedData {
            salt: salt.to_vec(),
            nonce: nonce_bytes,
            ciphertext,
        };
        serde_json::to_writer(file, &encrypted).map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Failed to write encrypted data to file: {}", e),
            )
        })
    }
}
/// Derives a 256-bit AES-GCM key from `password` and `salt` using `Argon2::default()`.
///
/// # Errors
/// Returns an `io::ErrorKind::Other` error ("Key derivation failed") if Argon2 rejects the
/// inputs; the underlying Argon2 error is not preserved.
fn derive_key(password: &str, salt: &[u8]) -> io::Result<Key<Aes256Gcm>> {
    let kdf = Argon2::default();
    let mut output = [0u8; 32];
    match kdf.hash_password_into(password.as_bytes(), salt, &mut output) {
        Ok(()) => Ok(*Key::<Aes256Gcm>::from_slice(&output)),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::Other,
            "Key derivation failed",
        )),
    }
}
/// An encrypted, file-backed store of a `T`, saved by writing a temporary file and renaming it
/// over the database file.
///
/// Obtained via `load`, `create`, or `EncryptedDataStore::open`. The data is saved when a guard
/// returned by `write()` is dropped, when the password is changed, and when this value is dropped.
pub struct EncryptedAtomicDatabase<T: EncryptedDataStore> {
    /// Path of the encrypted database file.
    path: PathBuf,
    /// Sibling temporary file (`.<file name>~`) written first, then renamed over `path`.
    tmp: PathBuf,
    /// The decrypted data.
    data: RwLock<T>,
    /// Encryption key derived from the current password and `salt`.
    key: RwLock<Key<Aes256Gcm>>,
    /// Salt written into the file envelope and used to derive `key`.
    salt: RwLock<Vec<u8>>,
}
impl<T: EncryptedDataStore + DeserializeOwned> EncryptedAtomicDatabase<T> {
pub fn load<P: AsRef<Path>>(path: P, password: &str) -> io::Result<Self> {
let new_path = path.as_ref().to_path_buf();
let tmp = Self::tmp_path(&new_path)?;
let file = File::open(&new_path)?;
let encrypted: EncryptedData = serde_json::from_reader(&file).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Failed to deserialize encrypted data: {}", e),
)
})?;
let key = derive_key(password, &encrypted.salt)?;
let file = File::open(&new_path)?;
let data = T::load_encrypted(file, &key)?;
Ok(Self {
path: new_path,
tmp,
data: RwLock::new(data),
key: RwLock::new(key),
salt: RwLock::new(encrypted.salt),
})
}
pub fn create<P: AsRef<Path>>(path: P, password: &str) -> io::Result<Self> {
let new_path = path.as_ref().to_path_buf();
let tmp = Self::tmp_path(&new_path)?;
let mut salt = vec![0u8; SALT_LEN];
OsRng.fill_bytes(&mut salt);
let key = derive_key(password, &salt)?;
let data = Default::default();
atomic_write_encrypted(&tmp, &new_path, &data, &key, &salt)?;
Ok(Self {
path: new_path,
tmp,
data: RwLock::new(data),
key: RwLock::new(key),
salt: RwLock::new(salt),
})
}
pub fn read(&self) -> EncryptedAtomicDatabaseRead<'_, T> {
EncryptedAtomicDatabaseRead {
data: self.data.read().unwrap(),
}
}
pub fn write(&self) -> EncryptedAtomicDatabaseWrite<'_, T> {
let key = *self.key.read().unwrap();
let salt = self.salt.read().unwrap().clone();
EncryptedAtomicDatabaseWrite {
path: self.path.as_ref(),
tmp: self.tmp.as_ref(),
data: self.data.write().unwrap(),
key,
salt,
}
}
pub fn change_password(&self, new_password: &str) -> io::Result<()> {
let data_guard = self.data.read().unwrap();
let mut new_salt = vec![0u8; SALT_LEN];
OsRng.fill_bytes(&mut new_salt);
let new_key = derive_key(new_password, &new_salt)?;
atomic_write_encrypted(&self.tmp, &self.path, &*data_guard, &new_key, &new_salt)?;
{
let mut key_lock = self.key.write().unwrap();
*key_lock = new_key;
}
{
let mut salt_lock = self.salt.write().unwrap();
*salt_lock = new_salt;
}
Ok(())
}
fn tmp_path(path: &Path) -> io::Result<PathBuf> {
let mut tmp_name = OsString::from(".");
tmp_name.push(path.file_name().unwrap_or(OsStr::new("db")));
tmp_name.push("~");
let tmp = path.with_file_name(tmp_name);
if tmp.exists() {
error!(
"Found orphaned database temporary file '{tmp:?}'. The server has recently crashed or is already running. Delete this before continuing!"
);
return Err(io::Error::new(
io::ErrorKind::AlreadyExists,
"Orphaned temporary file exists",
));
}
Ok(tmp)
}
}
/// Replaces `path` with an encrypted serialization of `data` by writing `tmp` and renaming it.
///
/// The temporary file is written through a buffer, flushed and fsynced before the rename, so
/// on filesystems where rename is atomic a reader (or a restart after a crash) sees either the
/// previous file or the complete new one. The containing directory is not fsynced, so the
/// rename itself may not survive a power loss on every filesystem.
///
/// # Errors
/// Any error from creating, writing, syncing or renaming the temporary file. On any failure
/// after `tmp` was created it is removed on a best-effort basis.
fn atomic_write_encrypted<T: EncryptedDataStore>(
    tmp: &Path,
    path: &Path,
    data: &T,
    key: &Key<Aes256Gcm>,
    salt: &[u8],
) -> io::Result<()> {
    // If `tmp` cannot even be created, there is nothing of ours to clean up.
    let file = File::create(tmp)?;
    let result =
        write_encrypted_synced(file, data, key, salt).and_then(|()| fs::rename(tmp, path));
    if result.is_err() {
        // Best effort: a leftover temp file would make the next `load`/`create` refuse to
        // start as if the process had crashed. The original error is what the caller needs.
        let _ = fs::remove_file(tmp);
    }
    result
}

/// Writes the encrypted envelope for `data` into `file` through a buffer, then flushes and
/// fsyncs it so the bytes are on disk before the caller renames the file into place.
fn write_encrypted_synced<T: EncryptedDataStore>(
    file: File,
    data: &T,
    key: &Key<Aes256Gcm>,
    salt: &[u8],
) -> io::Result<()> {
    let mut writer = io::BufWriter::new(file);
    data.save_encrypted(&mut writer, key, salt)?;
    // `into_inner` flushes and reports flush errors; dropping the BufWriter would swallow them.
    let file = writer.into_inner().map_err(|e| e.into_error())?;
    file.sync_all()
}
impl<T: EncryptedDataStore> fmt::Debug for EncryptedAtomicDatabase<T> {
    /// Shows only the backing file path (as `file`); the decrypted data, key and salt are
    /// never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("EncryptedAtomicDatabase");
        builder.field("file", &self.path);
        builder.finish()
    }
}
impl<T: EncryptedDataStore> Drop for EncryptedAtomicDatabase<T> {
    /// Saves the data one final time when the database is dropped.
    ///
    /// `&mut self` guarantees no guards are outstanding, so the locks are bypassed with
    /// `get_mut`. A poisoned lock (a thread panicked while holding it) is recovered instead of
    /// unwrapped: a panic inside `drop` while already unwinding would abort the process.
    /// Save errors are logged because `drop` cannot return them.
    fn drop(&mut self) {
        info!("Saving database");
        let data = self.data.get_mut().unwrap_or_else(|poisoned| poisoned.into_inner());
        let key = self.key.get_mut().unwrap_or_else(|poisoned| poisoned.into_inner());
        let salt = self.salt.get_mut().unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Err(e) = atomic_write_encrypted(&self.tmp, &self.path, &*data, &*key, &*salt) {
            error!("Failed to save database: {}", e);
        }
    }
}
/// Shared read access to the data of an `EncryptedAtomicDatabase`, returned by its `read()`.
///
/// Dereferences to `T` and holds the data read lock until dropped.
pub struct EncryptedAtomicDatabaseRead<'a, T: EncryptedDataStore> {
    /// Held read lock on the database's data.
    data: RwLockReadGuard<'a, T>,
}
impl<T: EncryptedDataStore> Deref for EncryptedAtomicDatabaseRead<'_, T> {
    type Target = T;

    /// Borrows the locked data for as long as the guard is alive.
    fn deref(&self) -> &T {
        self.data.deref()
    }
}
/// Exclusive write access to the data of an `EncryptedAtomicDatabase`, returned by its `write()`.
///
/// Dereferences mutably to `T`. When dropped, it re-encrypts and saves the data using the key
/// and salt captured when the guard was created; save errors are logged, not returned.
pub struct EncryptedAtomicDatabaseWrite<'a, T: EncryptedDataStore> {
    /// Temporary file written before being renamed over `path`.
    tmp: &'a Path,
    /// Database file replaced when the guard is dropped.
    path: &'a Path,
    /// Held write lock on the database's data.
    data: RwLockWriteGuard<'a, T>,
    /// Copy of the database key taken when the guard was created.
    key: Key<Aes256Gcm>,
    /// Copy of the database salt taken when the guard was created.
    salt: Vec<u8>,
}
impl<T: EncryptedDataStore> Deref for EncryptedAtomicDatabaseWrite<'_, T> {
    type Target = T;

    /// Borrows the exclusively locked data immutably.
    fn deref(&self) -> &T {
        self.data.deref()
    }
}
impl<T: EncryptedDataStore> DerefMut for EncryptedAtomicDatabaseWrite<'_, T> {
    /// Borrows the exclusively locked data mutably; changes are saved when the guard drops.
    fn deref_mut(&mut self) -> &mut T {
        self.data.deref_mut()
    }
}
impl<'a, T: EncryptedDataStore> Drop for EncryptedAtomicDatabaseWrite<'a, T> {
fn drop(&mut self) {
info!("Saving database");
if let Err(e) =
atomic_write_encrypted(self.tmp, self.path, &*self.data, &self.key, &self.salt)
{
error!("Failed to save database: {}", e);
}
}
}