use std::collections::HashMap;
use std::fmt::Debug;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use tempfile::NamedTempFile;
use aes_gcm_siv::Aes256GcmSiv;
use elements::hashes::hex::DisplayHex;
use crate::encrypt::{
cipher_from_key_bytes, decrypt_with_nonce_prefix, encrypt_with_deterministic_nonce,
encrypt_with_random_nonce, EncryptError,
};
/// Type-erased, thread-safe error used by the object-safe [`DynStore`] API.
///
/// Any `std::error::Error + Send + Sync + 'static` value (as well as `&str` /
/// `String` messages) converts into it via `From`/`?`.
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// Synchronous key-value storage backend addressed by byte-string keys.
///
/// Keys and values are accepted as anything implementing `AsRef<[u8]>`.
/// Individual backends may restrict which keys they accept (for example a
/// file-backed store that uses keys as file names).
pub trait Store: Send + Sync + Debug {
    /// Error type produced by this backend.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Returns the value stored under `key`, or `Ok(None)` if there is none.
    fn get<K>(&self, key: K) -> Result<Option<Vec<u8>>, Self::Error>
    where
        K: AsRef<[u8]>;

    /// Stores `value` under `key`, replacing any previous value.
    fn put<K, V>(&self, key: K, value: V) -> Result<(), Self::Error>
    where
        K: AsRef<[u8]>,
        V: AsRef<[u8]>;

    /// Deletes the value stored under `key`.
    fn remove<K>(&self, key: K) -> Result<(), Self::Error>
    where
        K: AsRef<[u8]>;

    /// Whether data written through this store outlives the current process.
    ///
    /// Defaults to `false`; disk-backed implementations override it.
    fn is_persisted(&self) -> bool {
        false
    }
}
/// Object-safe counterpart of [`Store`], usable as `dyn DynStore`.
///
/// Keys are `&str`, values are byte slices and errors are type-erased into
/// [`BoxError`]. Every [`Store`] implements this trait through a blanket impl
/// in this module.
pub trait DynStore: Send + Sync + Debug {
    /// Returns the value stored under `key`, or `Ok(None)` if there is none.
    fn get(&self, key: &str) -> Result<Option<Vec<u8>>, BoxError>;

    /// Stores `value` under `key`.
    fn put(&self, key: &str, value: &[u8]) -> Result<(), BoxError>;

    /// Deletes the value stored under `key`.
    fn remove(&self, key: &str) -> Result<(), BoxError>;

    /// Whether written data outlives the current process; `false` unless
    /// an implementation overrides it.
    fn is_persisted(&self) -> bool {
        false
    }
}
/// Error type of the [`Store`] impl for `Arc<dyn DynStore>`.
///
/// It transparently wraps either the boxed error returned by the underlying
/// [`DynStore`] or a UTF-8 decoding error for a non-UTF-8 key. Both `Display`
/// and `source` delegate straight to the wrapped error, so the wrapper adds no
/// extra layer to an error chain.
#[derive(Debug)]
pub struct ArcDynStoreError(BoxError);

impl std::fmt::Display for ArcDynStoreError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hand the formatter itself to the inner error (rather than using
        // `write!`) so flags such as width and alignment pass through.
        let wrapped: &(dyn std::error::Error + Send + Sync) = &*self.0;
        std::fmt::Display::fmt(wrapped, f)
    }
}

impl std::error::Error for ArcDynStoreError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Transparent wrapper: expose the wrapped error's cause, not the
        // wrapped error itself (its message is already our `Display`).
        let wrapped: &(dyn std::error::Error + Send + Sync + 'static) = self.0.as_ref();
        wrapped.source()
    }
}
/// Lets a type-erased `Arc<dyn DynStore>` be used wherever a generic [`Store`]
/// is expected.
///
/// `DynStore` only accepts `&str` keys, so byte keys are decoded as UTF-8
/// first. A non-UTF-8 key fails with the `Utf8Error` wrapped in
/// [`ArcDynStoreError`] and never reaches the underlying store.
impl Store for Arc<dyn DynStore> {
    type Error = ArcDynStoreError;

    fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, ArcDynStoreError> {
        let key = utf8_store_key(key.as_ref())?;
        DynStore::get(&**self, key).map_err(ArcDynStoreError)
    }

    fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        key: K,
        value: V,
    ) -> Result<(), ArcDynStoreError> {
        let key = utf8_store_key(key.as_ref())?;
        DynStore::put(&**self, key, value.as_ref()).map_err(ArcDynStoreError)
    }

    fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), ArcDynStoreError> {
        let key = utf8_store_key(key.as_ref())?;
        DynStore::remove(&**self, key).map_err(ArcDynStoreError)
    }

    fn is_persisted(&self) -> bool {
        DynStore::is_persisted(&**self)
    }
}

/// Decodes a raw store key as UTF-8, wrapping a decoding failure as an
/// [`ArcDynStoreError`].
fn utf8_store_key(raw: &[u8]) -> Result<&str, ArcDynStoreError> {
    std::str::from_utf8(raw).map_err(|err| ArcDynStoreError(Box::new(err)))
}
/// Every [`Store`] is automatically usable as a [`DynStore`].
///
/// `&str` keys are passed through as byte keys. Backend errors are boxed into
/// [`BoxError`] by the `?` operator's `From` conversion.
impl<S: Store> DynStore for S {
    fn get(&self, key: &str) -> Result<Option<Vec<u8>>, BoxError> {
        Ok(Store::get(self, key)?)
    }

    fn put(&self, key: &str, value: &[u8]) -> Result<(), BoxError> {
        Store::put(self, key, value)?;
        Ok(())
    }

    fn remove(&self, key: &str) -> Result<(), BoxError> {
        Store::remove(self, key)?;
        Ok(())
    }

    fn is_persisted(&self) -> bool {
        <S as Store>::is_persisted(self)
    }
}
/// Non-persistent [`Store`] that keeps all entries in a mutex-guarded
/// `HashMap`.
#[derive(Debug, Default)]
pub struct MemoryStore {
    /// Raw key bytes mapped to raw value bytes.
    data: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
}

impl MemoryStore {
    /// Creates an empty store (equivalent to `MemoryStore::default()`).
    pub fn new() -> Self {
        let empty: HashMap<Vec<u8>, Vec<u8>> = HashMap::new();
        MemoryStore {
            data: Mutex::new(empty),
        }
    }
}
/// In-memory [`Store`]. Each operation holds the mutex for its whole
/// duration. No operation can fail, hence the `Infallible` error type.
///
/// # Panics
///
/// Every method panics with "lock poisoned" if another thread panicked while
/// holding the lock.
impl Store for MemoryStore {
    type Error = std::convert::Infallible;

    fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, Self::Error> {
        let map = self.data.lock().expect("lock poisoned");
        // Clone out so the caller owns the bytes after the lock is released.
        Ok(map.get(key.as_ref()).map(Vec::clone))
    }

    fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(&self, key: K, value: V) -> Result<(), Self::Error> {
        let mut map = self.data.lock().expect("lock poisoned");
        map.insert(key.as_ref().to_vec(), value.as_ref().to_vec());
        Ok(())
    }

    fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), Self::Error> {
        let mut map = self.data.lock().expect("lock poisoned");
        // Removing an absent key is a no-op, not an error.
        map.remove(key.as_ref());
        Ok(())
    }
}
/// Stateless no-op store: writes and removals are discarded and every lookup
/// reports the key as absent.
#[derive(Debug, Default, Clone, Copy)]
pub struct FakeStore;

impl FakeStore {
    /// Returns a `FakeStore` (it carries no state, so all instances are equal).
    pub fn new() -> Self {
        FakeStore
    }
}
/// [`Store`] that forgets everything. `put` and `remove` succeed without
/// effect, and `get` always returns `Ok(None)`.
impl Store for FakeStore {
    type Error = std::convert::Infallible;

    fn get<K: AsRef<[u8]>>(&self, _ignored_key: K) -> Result<Option<Vec<u8>>, Self::Error> {
        Ok(Option::None)
    }

    fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        _ignored_key: K,
        _ignored_value: V,
    ) -> Result<(), Self::Error> {
        Ok(())
    }

    fn remove<K: AsRef<[u8]>>(&self, _ignored_key: K) -> Result<(), Self::Error> {
        Ok(())
    }
}
/// [`Store`] that keeps each value in its own file directly under a root
/// directory, using the (UTF-8) key as the file name.
///
/// All operations lock `root`, so operations on one `FileStore` instance are
/// serialized.
#[derive(Debug)]
pub struct FileStore {
    root: Mutex<PathBuf>,
}

impl FileStore {
    /// Opens a store rooted at `path`, creating the directory and any missing
    /// parents if needed.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `path` is an existing regular file. Otherwise
    /// returns the underlying I/O error if the directory cannot be created.
    pub fn new(path: PathBuf) -> Result<Self, std::io::Error> {
        if path.is_file() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "FileStore root path is a file",
            ));
        }
        // `create_dir_all` succeeds for an existing directory, so a separate
        // `exists()` pre-check (and its check-then-act race) is unnecessary.
        fs::create_dir_all(&path)?;
        Ok(Self {
            root: Mutex::new(path),
        })
    }

    /// Maps `key` to a file path directly inside `root`.
    ///
    /// The key must satisfy all of the following:
    /// - it is non-empty UTF-8 of at most 255 bytes;
    /// - it is not `.` or `..`;
    /// - it contains no path separator, no character Windows forbids in file
    ///   names (`: * ? " < > |`), and no ASCII control character U+0000–U+001F.
    ///
    /// Together these ensure the key always names a single, portable entry
    /// inside `root`.
    ///
    /// NOTE(review): Windows reserved device names (`CON`, `NUL`, `COM1`, …)
    /// and trailing dots/spaces are not rejected.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error describing the first violated rule.
    fn file_path(root: &Path, key: &[u8]) -> Result<PathBuf, std::io::Error> {
        let name = std::str::from_utf8(key).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "store key is not valid UTF-8",
            )
        })?;
        if name.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "store key is empty",
            ));
        }
        // Common file-system limit on a single name component, in bytes.
        if name.len() > 255 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "store key exceeds maximum file name length (255 bytes)",
            ));
        }
        let has_invalid_char = name.chars().any(|c| {
            matches!(
                c,
                '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' | '\0'..='\x1f'
            )
        });
        if name == "." || name == ".." || has_invalid_char {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "store key contains invalid file name characters",
            ));
        }
        Ok(root.join(name))
    }

    /// Flushes directory metadata for `path` (e.g. a rename or unlink inside
    /// it) to disk by fsync-ing the directory itself.
    #[cfg(not(target_os = "windows"))]
    fn sync_dir(path: &Path) -> Result<(), std::io::Error> {
        fs::File::open(path)?.sync_all()
    }

    /// No-op on Windows, where a directory cannot be opened with
    /// `File::open` to be synced.
    #[cfg(target_os = "windows")]
    fn sync_dir(_path: &Path) -> Result<(), std::io::Error> {
        Ok(())
    }
}
/// File-backed [`Store`]: one file per key under the store's root directory.
///
/// Writes are crash-safe. The value is written to a temporary file in the
/// same directory, fsync'd, atomically renamed over the target, and then the
/// directory itself is fsync'd. Removals also fsync the directory, so a
/// completed `remove` cannot be undone by a crash.
impl Store for FileStore {
    type Error = std::io::Error;

    fn is_persisted(&self) -> bool {
        true
    }

    /// Reads the file for `key`; a missing file maps to `Ok(None)`.
    fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, Self::Error> {
        let root = self.root.lock().expect("lock poisoned");
        let path = Self::file_path(&root, key.as_ref())?;
        match fs::read(path) {
            Ok(bytes) => Ok(Some(bytes)),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
            Err(e) => Err(e),
        }
    }

    /// Atomically replaces the file for `key` with `value`.
    fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(&self, key: K, value: V) -> Result<(), Self::Error> {
        let root = self.root.lock().expect("lock poisoned");
        let path = Self::file_path(&root, key.as_ref())?;
        // The temp file lives in `root` so the rename stays on one file system.
        let mut tmp = NamedTempFile::new_in(&*root)?;
        tmp.write_all(value.as_ref())?;
        tmp.as_file().sync_all()?;
        match tmp.persist(&path) {
            Ok(_) => {}
            // Fallback for platforms where the rename refuses to overwrite:
            // delete the old file and retry. This path is not atomic.
            Err(e) if e.error.kind() == std::io::ErrorKind::AlreadyExists => {
                match fs::remove_file(&path) {
                    Ok(()) => {}
                    Err(remove_err) if remove_err.kind() == std::io::ErrorKind::NotFound => {}
                    Err(remove_err) => return Err(remove_err),
                }
                e.file
                    .persist(&path)
                    .map_err(|persist_err| persist_err.error)?;
            }
            Err(e) => return Err(e.error),
        }
        // Make the rename itself durable.
        Self::sync_dir(root.as_path())?;
        Ok(())
    }

    /// Deletes the file for `key`; deleting an absent key is not an error.
    fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), Self::Error> {
        let root = self.root.lock().expect("lock poisoned");
        let path = Self::file_path(&root, key.as_ref())?;
        match fs::remove_file(path) {
            // Mirror `put`: without syncing the directory, the unlink may not
            // survive a crash and the removed value could reappear.
            Ok(()) => Self::sync_dir(root.as_path()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(e),
        }
    }
}
/// Error returned by [`EncryptedStore`]: either the wrapped store failed, or
/// encrypting/decrypting a key or value failed.
#[derive(Debug)]
pub enum EncryptedStoreError<E: std::error::Error + Send + Sync + 'static> {
    /// Error from the inner store.
    Store(E),
    /// Error from the encryption layer.
    Encrypt(EncryptError),
}

impl<E: std::error::Error + Send + Sync + 'static> std::fmt::Display for EncryptedStoreError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (layer, detail) = match self {
            Self::Store(inner) => ("store", inner as &dyn std::fmt::Display),
            Self::Encrypt(inner) => ("encryption", inner as &dyn std::fmt::Display),
        };
        write!(f, "{layer} error: {detail}")
    }
}

impl<E: std::error::Error + Send + Sync + 'static> std::error::Error for EncryptedStoreError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // The wrapped error is always the direct cause.
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Store(inner) => inner,
            Self::Encrypt(inner) => inner,
        };
        Some(cause)
    }
}
/// [`Store`] adapter that encrypts values before handing them to an inner
/// store.
///
/// Values are encrypted with a random nonce. When `encrypt_keys` is set, keys
/// are also encrypted (with a deterministic nonce) and hex-encoded before use.
pub struct EncryptedStore<S> {
    inner: S,
    /// Raw 32-byte AES key. Secret: never printed (see the `Debug` impl).
    key_bytes: [u8; 32],
    /// Whether keys are encrypted in addition to values.
    encrypt_keys: bool,
}

// Hand-written instead of derived: a derived `Debug` would print the raw key
// bytes, and `Store: Debug` makes it likely that stores end up in logs.
impl<S: std::fmt::Debug> std::fmt::Debug for EncryptedStore<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptedStore")
            .field("inner", &self.inner)
            .field("key_bytes", &"<redacted>")
            .field("encrypt_keys", &self.encrypt_keys)
            .finish()
    }
}
impl<S> EncryptedStore<S> {
pub fn new(inner: S, key_bytes: [u8; 32]) -> Self {
Self {
inner,
key_bytes,
encrypt_keys: false,
}
}
pub fn new_with_key_encryption(inner: S, key_bytes: [u8; 32]) -> Self {
Self {
inner,
key_bytes,
encrypt_keys: true,
}
}
pub fn inner(&self) -> &S {
&self.inner
}
pub fn into_inner(self) -> S {
self.inner
}
pub fn cipher(&self) -> Aes256GcmSiv {
cipher_from_key_bytes(self.key_bytes)
}
}
impl<S: Store> EncryptedStore<S> {
    /// Returns the key under which the entry for `key` lives in the inner
    /// store.
    ///
    /// Without key encryption this is the raw key bytes. With key encryption it
    /// is the lowercase-hex encoding of the key encrypted via
    /// `encrypt_with_deterministic_nonce`. That function is presumably
    /// deterministic, which is what lets repeated lookups of one key find the
    /// same entry. Hex encoding keeps the result printable ASCII.
    fn effective_key<K: AsRef<[u8]>>(
        &self,
        key: K,
    ) -> Result<Vec<u8>, EncryptedStoreError<S::Error>> {
        let raw = key.as_ref();
        if !self.encrypt_keys {
            return Ok(raw.to_vec());
        }
        let mut cipher = cipher_from_key_bytes(self.key_bytes);
        encrypt_with_deterministic_nonce(&mut cipher, raw)
            .map_err(EncryptedStoreError::Encrypt)
            .map(|sealed| sealed.to_lower_hex_string().into_bytes())
    }
}
/// Encrypting [`Store`] adapter.
///
/// Each value is encrypted with a random nonce (`encrypt_with_random_nonce`)
/// on `put`, and decrypted with `decrypt_with_nonce_prefix` on `get`. Keys are
/// mapped through `effective_key` first. Persistence is whatever the inner
/// store provides.
impl<S: Store> Store for EncryptedStore<S> {
    type Error = EncryptedStoreError<S::Error>;

    fn is_persisted(&self) -> bool {
        self.inner.is_persisted()
    }

    fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, Self::Error> {
        let storage_key = self.effective_key(key)?;
        let stored = self
            .inner
            .get(&storage_key)
            .map_err(EncryptedStoreError::Store)?;
        // Decrypt only when something was found; a missing entry stays `None`.
        stored
            .map(|ciphertext| {
                let mut cipher = cipher_from_key_bytes(self.key_bytes);
                decrypt_with_nonce_prefix(&mut cipher, &ciphertext)
                    .map_err(EncryptedStoreError::Encrypt)
            })
            .transpose()
    }

    fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(&self, key: K, value: V) -> Result<(), Self::Error> {
        let storage_key = self.effective_key(key)?;
        let mut cipher = cipher_from_key_bytes(self.key_bytes);
        let sealed = encrypt_with_random_nonce(&mut cipher, value.as_ref())
            .map_err(EncryptedStoreError::Encrypt)?;
        self.inner
            .put(&storage_key, sealed)
            .map_err(EncryptedStoreError::Store)
    }

    fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), Self::Error> {
        let storage_key = self.effective_key(key)?;
        self.inner
            .remove(&storage_key)
            .map_err(EncryptedStoreError::Store)
    }
}
#[cfg(test)]
mod test {
    use super::{EncryptedStore, FakeStore, FileStore, MemoryStore, Store};

    #[test]
    fn memory_store() {
        let store = MemoryStore::new();
        assert_eq!(store.get("key").unwrap(), None);
        store.put("key", b"value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"value".to_vec()));
        store.put("key", b"new_value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"new_value".to_vec()));
        store.remove("key").unwrap();
        assert_eq!(store.get("key").unwrap(), None);
        store.remove("key").unwrap();
    }

    #[test]
    fn file_store_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileStore::new(dir.path().to_path_buf()).unwrap();
        assert_eq!(store.get("key").unwrap(), None);
        store.put("key", b"value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"value".to_vec()));
        store.put("key2", b"value2").unwrap();
        assert_eq!(store.get("key2").unwrap(), Some(b"value2".to_vec()));
        store.put("key", b"new_value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"new_value".to_vec()));
        let non_utf8_key = [0u8, 255u8, 1u8];
        assert!(store.put(non_utf8_key, b"bin").is_err());
        store.remove("key").unwrap();
        assert_eq!(store.get("key").unwrap(), None);
        store.remove("key").unwrap();
        drop(store);
        let store = FileStore::new(dir.path().to_path_buf()).unwrap();
        assert_eq!(store.get("key").unwrap(), None);
        assert_eq!(store.get("key2").unwrap(), Some(b"value2".to_vec()));
    }

    /// Keys that could escape the root directory or are not portable file
    /// names must be rejected by every operation.
    #[test]
    fn file_store_rejects_unsafe_keys() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileStore::new(dir.path().to_path_buf()).unwrap();
        let unsafe_keys = [
            "", ".", "..", "../escape", "a/b", "a\\b", "a\0b", "a:b", "a*b", "a?b", "a<b", "a>b",
            "a|b",
        ];
        for key in unsafe_keys {
            assert!(store.put(key, b"x").is_err(), "put accepted unsafe key {key:?}");
            assert!(store.get(key).is_err(), "get accepted unsafe key {key:?}");
            assert!(store.remove(key).is_err(), "remove accepted unsafe key {key:?}");
        }
        assert!(store.put("k".repeat(256), b"x").is_err());
    }

    #[test]
    fn file_store_rejects_file_root() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("not_a_dir");
        std::fs::write(&file, b"").unwrap();
        let err = FileStore::new(file).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    fn fake_store() {
        let store = FakeStore::new();
        assert_eq!(store.get("key").unwrap(), None);
        store.put("key", b"value").unwrap();
        assert_eq!(store.get("key").unwrap(), None);
        store.remove("key").unwrap();
    }

    #[test]
    fn encrypted_store_memory() {
        let key_bytes = [7u8; 32];
        let inner = MemoryStore::new();
        let store = EncryptedStore::new(inner, key_bytes);
        assert_eq!(store.get("key").unwrap(), None);
        store.put("key", b"secret value").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"secret value".to_vec()));
        let raw = store.inner().get("key").unwrap().unwrap();
        assert_ne!(raw, b"secret value".to_vec());
        store.put("key", b"new secret").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"new secret".to_vec()));
        store.remove("key").unwrap();
        assert_eq!(store.get("key").unwrap(), None);
    }

    /// With key encryption the plaintext key never reaches the inner store,
    /// yet the same logical key keeps mapping to the same entry.
    #[test]
    fn encrypted_store_key_encryption() {
        let store = EncryptedStore::new_with_key_encryption(MemoryStore::new(), [9u8; 32]);
        assert_eq!(store.get("key").unwrap(), None);
        store.put("key", b"secret").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"secret".to_vec()));
        assert_eq!(store.inner().get("key").unwrap(), None);
        store.put("key", b"updated").unwrap();
        assert_eq!(store.get("key").unwrap(), Some(b"updated".to_vec()));
        store.remove("key").unwrap();
        assert_eq!(store.get("key").unwrap(), None);
    }

    /// Hex-encoded encrypted keys are valid file names even when the logical
    /// key contains characters `FileStore` would reject.
    #[test]
    fn encrypted_store_key_encryption_file() {
        let dir = tempfile::tempdir().unwrap();
        let inner = FileStore::new(dir.path().to_path_buf()).unwrap();
        let store = EncryptedStore::new_with_key_encryption(inner, [3u8; 32]);
        store.put("a/b", b"payload").unwrap();
        assert_eq!(store.get("a/b").unwrap(), Some(b"payload".to_vec()));
        let names: Vec<String> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name().into_string().unwrap())
            .collect();
        assert_eq!(names.len(), 1);
        assert!(names[0]
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
    }

    #[test]
    fn encrypted_store_file() {
        let key_bytes = [42u8; 32];
        let dir = tempfile::tempdir().unwrap();
        let inner = FileStore::new(dir.path().to_path_buf()).unwrap();
        let store = EncryptedStore::new(inner, key_bytes);
        store.put("000000000000", b"update data").unwrap();
        assert_eq!(
            store.get("000000000000").unwrap(),
            Some(b"update data".to_vec())
        );
        let file_path = dir.path().join("000000000000");
        let raw_bytes = std::fs::read(&file_path).unwrap();
        assert_ne!(raw_bytes, b"update data".to_vec());
        drop(store);
        let inner = FileStore::new(dir.path().to_path_buf()).unwrap();
        let store = EncryptedStore::new(inner, key_bytes);
        assert_eq!(
            store.get("000000000000").unwrap(),
            Some(b"update data".to_vec())
        );
        let inner = FileStore::new(dir.path().to_path_buf()).unwrap();
        let wrong_store = EncryptedStore::new(inner, [0u8; 32]);
        assert!(wrong_store.get("000000000000").is_err());
    }
}