use crate::encrypt::{decrypt_with_password, encrypt, encryption_key_from_pass, Cipher};
use crate::error::{MutinyError, MutinyStorageError};
use crate::ldkstorage::CHANNEL_MANAGER_KEY;
use crate::nodemanager::{NodeStorage, DEVICE_LOCK_INTERVAL_SECS};
use crate::utils::{now, spawn};
use crate::vss::{MutinyVssClient, VssKeyValueItem};
use bdk::chain::{Append, PersistBackend};
use bip39::Mnemonic;
use lightning::log_error;
use lightning::util::logger::Logger;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use uuid::Uuid;
pub const KEYCHAIN_STORE_KEY: &str = "bdk_keychain";
pub(crate) const MNEMONIC_KEY: &str = "mnemonic";
pub const NODES_KEY: &str = "nodes";
const FEE_ESTIMATES_KEY: &str = "fee_estimates";
const FIRST_SYNC_KEY: &str = "first_sync";
pub(crate) const DEVICE_ID_KEY: &str = "device_id";
pub const DEVICE_LOCK_KEY: &str = "device_lock";
fn needs_encryption(key: &str) -> bool {
match key {
MNEMONIC_KEY => true,
str if str.starts_with(CHANNEL_MANAGER_KEY) => true,
_ => false,
}
}
pub fn encrypt_value(
key: impl AsRef<str>,
value: Value,
cipher: Option<Cipher>,
) -> Result<Value, MutinyError> {
let res = match cipher {
Some(c) if needs_encryption(key.as_ref()) => {
let str = serde_json::to_string(&value)?;
let ciphertext = encrypt(&str, c)?;
Value::String(ciphertext)
}
_ => value,
};
Ok(res)
}
pub fn decrypt_value(
key: impl AsRef<str>,
value: Value,
password: Option<&str>,
) -> Result<Value, MutinyError> {
let json: Value = match password {
Some(pw) if needs_encryption(key.as_ref()) => {
let str: String = serde_json::from_value(value)?;
let ciphertext = decrypt_with_password(&str, pw)?;
serde_json::from_str(&ciphertext)?
}
_ => value,
};
Ok(json)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionedValue {
pub version: u32,
pub value: Value,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DeviceLock {
pub time: u32,
pub device: String,
}
impl DeviceLock {
pub fn is_locked(&self, id: &str) -> bool {
let now = now().as_secs();
let diff = now - self.time as u64;
diff < DEVICE_LOCK_INTERVAL_SECS * 2 && self.device != id
}
}
pub trait MutinyStorage: Clone + Sized + 'static {
fn password(&self) -> Option<&str>;
fn cipher(&self) -> Option<Cipher>;
fn vss_client(&self) -> Option<Arc<MutinyVssClient>>;
fn set<T>(&self, key: impl AsRef<str>, value: T) -> Result<(), MutinyError>
where
T: Serialize;
fn set_data<T>(
&self,
key: impl AsRef<str>,
value: T,
version: Option<u32>,
) -> Result<(), MutinyError>
where
T: Serialize,
{
let data = serde_json::to_value(value).map_err(|e| MutinyError::PersistenceFailed {
source: MutinyStorageError::SerdeError { source: e },
})?;
if let (Some(vss), Some(version)) = (self.vss_client(), version) {
let item = VssKeyValueItem {
key: key.as_ref().to_string(),
value: data.clone(),
version,
};
spawn(async move {
if let Err(e) = vss.put_objects(vec![item]).await {
log_error!(vss.logger, "Failed to put object in VSS: {e}");
}
});
}
let json: Value = encrypt_value(key.as_ref(), data, self.cipher())?;
self.set(key, json)
}
fn get<T>(&self, key: impl AsRef<str>) -> Result<Option<T>, MutinyError>
where
T: for<'de> Deserialize<'de>;
fn get_data<T>(&self, key: impl AsRef<str>) -> Result<Option<T>, MutinyError>
where
T: for<'de> Deserialize<'de>,
{
match self.get(&key)? {
None => Ok(None),
Some(value) => {
let json: Value = decrypt_value(&key, value, self.password())?;
let data: T = serde_json::from_value(json)?;
Ok(Some(data))
}
}
}
fn delete(&self, keys: &[impl AsRef<str>]) -> Result<(), MutinyError>;
async fn start(&mut self) -> Result<(), MutinyError>;
fn stop(&self);
fn connected(&self) -> Result<bool, MutinyError>;
fn scan_keys(&self, prefix: &str, suffix: Option<&str>) -> Result<Vec<String>, MutinyError>;
fn scan<T>(&self, prefix: &str, suffix: Option<&str>) -> Result<HashMap<String, T>, MutinyError>
where
T: for<'de> Deserialize<'de>,
{
let keys = self.scan_keys(prefix, suffix)?;
let mut map = HashMap::with_capacity(keys.len());
for key in keys {
let kv = self.get_data::<T>(&key)?;
if let Some(v) = kv {
map.insert(key, v);
}
}
Ok(map)
}
fn insert_mnemonic(&self, mnemonic: Mnemonic) -> Result<Mnemonic, MutinyError> {
self.set_data(MNEMONIC_KEY, &mnemonic, None)?;
Ok(mnemonic)
}
fn get_mnemonic(&self) -> Result<Option<Mnemonic>, MutinyError> {
self.get_data(MNEMONIC_KEY)
}
fn change_password(
&mut self,
new: Option<String>,
new_cipher: Option<Cipher>,
) -> Result<(), MutinyError>;
fn change_password_and_rewrite_storage(
&mut self,
old: Option<String>,
new: Option<String>,
) -> Result<(), MutinyError> {
if old != self.password().map(|s| s.to_owned()) {
return Err(MutinyError::IncorrectPassword);
}
let mut keys: Vec<String> = self.scan_keys("", None)?;
keys.retain(|k| needs_encryption(k));
let mut values: HashMap<String, Value> = HashMap::new();
for key in keys.iter() {
let value = self.get_data(key)?;
if let Some(v) = value {
values.insert(key.to_owned(), v);
}
}
let new_cipher = new
.as_ref()
.filter(|p| !p.is_empty())
.map(|p| encryption_key_from_pass(p))
.transpose()?;
self.change_password(new, new_cipher)?;
for (key, value) in values.iter() {
self.set_data(key, value, None)?;
}
Ok(())
}
async fn import(json: Value) -> Result<(), MutinyError>;
async fn clear() -> Result<(), MutinyError>;
async fn delete_all(&self) -> Result<(), MutinyError> {
Self::clear().await?;
if self.vss_client().is_some() {
let device = self.get_device_id()?;
let lock = DeviceLock { time: 0, device };
let time = now().as_secs() as u32;
self.set_data(DEVICE_LOCK_KEY, lock, Some(time))?;
}
Ok(())
}
fn get_nodes(&self) -> Result<NodeStorage, MutinyError> {
let res: Option<NodeStorage> = self.get_data(NODES_KEY)?;
match res {
Some(nodes) => Ok(nodes),
None => Ok(NodeStorage::default()),
}
}
fn insert_nodes(&self, nodes: NodeStorage) -> Result<(), MutinyError> {
let version = Some(nodes.version);
self.set_data(NODES_KEY, nodes, version)
}
fn get_fee_estimates(&self) -> Result<Option<HashMap<String, f64>>, MutinyError> {
self.get_data(FEE_ESTIMATES_KEY)
}
fn insert_fee_estimates(&self, fees: HashMap<String, f64>) -> Result<(), MutinyError> {
self.set_data(FEE_ESTIMATES_KEY, fees, None)
}
fn has_done_first_sync(&self) -> Result<bool, MutinyError> {
self.get_data::<bool>(FIRST_SYNC_KEY)
.map(|v| v == Some(true))
}
fn set_done_first_sync(&self) -> Result<(), MutinyError> {
self.set_data(FIRST_SYNC_KEY, true, None)
}
fn get_device_id(&self) -> Result<String, MutinyError> {
match self.get_data(DEVICE_ID_KEY)? {
Some(id) => Ok(id),
None => {
let new_id = Uuid::new_v4().to_string();
self.set_data(DEVICE_ID_KEY, &new_id, None)?;
Ok(new_id)
}
}
}
fn get_device_lock(&self) -> Result<Option<DeviceLock>, MutinyError> {
self.get_data(DEVICE_LOCK_KEY)
}
fn set_device_lock(&self) -> Result<(), MutinyError> {
let device = self.get_device_id()?;
if let Some(lock) = self.get_device_lock()? {
if lock.is_locked(&device) {
return Err(MutinyError::AlreadyRunning);
}
}
let time = now().as_secs() as u32;
let lock = DeviceLock { time, device };
self.set_data(DEVICE_LOCK_KEY, lock, Some(time))
}
}
#[derive(Clone)]
pub struct MemoryStorage {
pub password: Option<String>,
pub cipher: Option<Cipher>,
pub memory: Arc<RwLock<HashMap<String, Value>>>,
pub vss_client: Option<Arc<MutinyVssClient>>,
}
impl MemoryStorage {
pub fn new(
password: Option<String>,
cipher: Option<Cipher>,
vss_client: Option<Arc<MutinyVssClient>>,
) -> Self {
Self {
cipher,
password,
memory: Arc::new(RwLock::new(HashMap::new())),
vss_client,
}
}
pub async fn load_from_vss(&self) -> Result<(), MutinyError> {
if let Some(vss) = self.vss_client() {
let keys = vss.list_key_versions(None).await?;
let mut items = HashMap::new();
for key in keys {
let obj = vss.get_object(&key.key).await?;
items.insert(key.key, obj.value);
}
let mut map = self
.memory
.try_write()
.map_err(|e| MutinyError::write_err(e.into()))?;
map.extend(items);
}
Ok(())
}
}
impl Default for MemoryStorage {
fn default() -> Self {
Self::new(None, None, None)
}
}
impl MutinyStorage for MemoryStorage {
fn password(&self) -> Option<&str> {
self.password.as_deref()
}
fn cipher(&self) -> Option<Cipher> {
self.cipher.to_owned()
}
fn vss_client(&self) -> Option<Arc<MutinyVssClient>> {
self.vss_client.clone()
}
fn set<T>(&self, key: impl AsRef<str>, value: T) -> Result<(), MutinyError>
where
T: Serialize,
{
let key = key.as_ref().to_string();
let data = serde_json::to_value(value).map_err(|e| MutinyError::PersistenceFailed {
source: MutinyStorageError::SerdeError { source: e },
})?;
let mut map = self
.memory
.try_write()
.map_err(|e| MutinyError::write_err(e.into()))?;
map.insert(key, data);
Ok(())
}
fn get<T>(&self, key: impl AsRef<str>) -> Result<Option<T>, MutinyError>
where
T: for<'de> Deserialize<'de>,
{
let map = self
.memory
.try_read()
.map_err(|e| MutinyError::read_err(e.into()))?;
match map.get(key.as_ref()) {
None => Ok(None),
Some(value) => {
let data: T = serde_json::from_value(value.to_owned())?;
Ok(Some(data))
}
}
}
fn delete(&self, keys: &[impl AsRef<str>]) -> Result<(), MutinyError> {
let keys: Vec<String> = keys.iter().map(|k| k.as_ref().to_string()).collect();
let mut map = self
.memory
.try_write()
.map_err(|e| MutinyError::write_err(e.into()))?;
for key in keys {
map.remove(&key);
}
Ok(())
}
async fn start(&mut self) -> Result<(), MutinyError> {
Ok(())
}
fn stop(&self) {}
fn connected(&self) -> Result<bool, MutinyError> {
Ok(false)
}
fn scan_keys(&self, prefix: &str, suffix: Option<&str>) -> Result<Vec<String>, MutinyError> {
let map = self
.memory
.try_read()
.map_err(|e| MutinyError::read_err(e.into()))?;
Ok(map
.keys()
.filter(|key| {
key.starts_with(prefix) && (suffix.is_none() || key.ends_with(suffix.unwrap()))
})
.cloned()
.collect())
}
fn change_password(
&mut self,
new: Option<String>,
new_cipher: Option<Cipher>,
) -> Result<(), MutinyError> {
self.password = new;
self.cipher = new_cipher;
Ok(())
}
async fn import(_json: Value) -> Result<(), MutinyError> {
Ok(())
}
async fn clear() -> Result<(), MutinyError> {
Ok(())
}
}
impl MutinyStorage for () {
fn password(&self) -> Option<&str> {
None
}
fn cipher(&self) -> Option<Cipher> {
None
}
fn vss_client(&self) -> Option<Arc<MutinyVssClient>> {
None
}
fn set<T>(&self, _key: impl AsRef<str>, _value: T) -> Result<(), MutinyError>
where
T: Serialize,
{
Ok(())
}
fn get<T>(&self, _key: impl AsRef<str>) -> Result<Option<T>, MutinyError>
where
T: for<'de> Deserialize<'de>,
{
Ok(None)
}
fn delete(&self, _keys: &[impl AsRef<str>]) -> Result<(), MutinyError> {
Ok(())
}
async fn start(&mut self) -> Result<(), MutinyError> {
Ok(())
}
fn stop(&self) {}
fn connected(&self) -> Result<bool, MutinyError> {
Ok(false)
}
fn scan_keys(&self, _prefix: &str, _suffix: Option<&str>) -> Result<Vec<String>, MutinyError> {
Ok(Vec::new())
}
fn change_password(
&mut self,
_new: Option<String>,
_new_cipher: Option<Cipher>,
) -> Result<(), MutinyError> {
Ok(())
}
async fn import(_json: Value) -> Result<(), MutinyError> {
Ok(())
}
async fn clear() -> Result<(), MutinyError> {
Ok(())
}
}
#[derive(Clone)]
pub struct OnChainStorage<S: MutinyStorage>(pub(crate) S);
impl<K, S: MutinyStorage> PersistBackend<K> for OnChainStorage<S>
where
K: Default + Clone + Append + serde::Serialize + serde::de::DeserializeOwned,
{
type WriteError = MutinyError;
type LoadError = MutinyError;
fn write_changes(&mut self, changeset: &K) -> Result<(), Self::WriteError> {
if changeset.is_empty() {
return Ok(());
}
match self.0.get_data::<K>(KEYCHAIN_STORE_KEY)? {
Some(mut keychain_store) => {
keychain_store.append(changeset.clone());
self.0.set_data(KEYCHAIN_STORE_KEY, keychain_store, None)
}
None => self.0.set_data(KEYCHAIN_STORE_KEY, changeset, None),
}
}
fn load_from_persistence(&mut self) -> Result<K, Self::LoadError> {
if let Some(k) = self.0.get_data(KEYCHAIN_STORE_KEY)? {
Ok(k)
} else {
Ok(K::default())
}
}
}
#[cfg(test)]
mod tests {
use crate::test_utils::*;
use crate::utils::sleep;
use crate::{encrypt::encryption_key_from_pass, storage::MemoryStorage};
use crate::{keymanager, storage::MutinyStorage};
use wasm_bindgen_test::{wasm_bindgen_test as test, wasm_bindgen_test_configure};
wasm_bindgen_test_configure!(run_in_browser);
#[test]
fn insert_and_get_mnemonic_no_password() {
let test_name = "insert_and_get_mnemonic_no_password";
log!("{}", test_name);
let seed = keymanager::generate_seed(12).unwrap();
let storage = MemoryStorage::default();
let mnemonic = storage.insert_mnemonic(seed).unwrap();
let stored_mnemonic = storage.get_mnemonic().unwrap();
assert_eq!(Some(mnemonic), stored_mnemonic);
}
#[test]
fn insert_and_get_mnemonic_with_password() {
let test_name = "insert_and_get_mnemonic_with_password";
log!("{}", test_name);
let seed = keymanager::generate_seed(12).unwrap();
let pass = uuid::Uuid::new_v4().to_string();
let cipher = encryption_key_from_pass(&pass).unwrap();
let storage = MemoryStorage::new(Some(pass), Some(cipher), None);
let mnemonic = storage.insert_mnemonic(seed).unwrap();
let stored_mnemonic = storage.get_mnemonic().unwrap();
assert_eq!(Some(mnemonic), stored_mnemonic);
}
#[test]
async fn test_device_lock() {
let test_name = "test_device_lock";
log!("{}", test_name);
let vss = std::sync::Arc::new(create_vss_client().await);
let storage = MemoryStorage::new(None, None, Some(vss.clone()));
storage.load_from_vss().await.unwrap();
let id = storage.get_device_id().unwrap();
let lock = storage.get_device_lock().unwrap();
assert_eq!(None, lock);
storage.set_device_lock().unwrap();
sleep(1_000).await;
let lock = storage.get_device_lock().unwrap();
assert!(lock.is_some());
assert!(!lock.clone().unwrap().is_locked(&id));
assert!(lock.clone().unwrap().is_locked("different_id"));
assert_eq!(lock.unwrap().device, id);
storage.set_device_lock().unwrap();
sleep(1_000).await;
let storage = MemoryStorage::new(None, None, Some(vss));
storage.load_from_vss().await.unwrap();
let new_id = storage.get_device_id().unwrap();
assert_ne!(id, new_id);
let lock = storage.get_device_lock().unwrap();
assert!(lock.is_some());
assert!(!lock.clone().unwrap().is_locked(&id));
assert!(lock.clone().unwrap().is_locked(&new_id));
assert_eq!(lock.unwrap().device, id);
assert!(storage.set_device_lock().is_err())
}
}