use std::boxed::Box;
use std::collections::{hash_map, HashMap};
use std::future::Future;
use std::panic::RefUnwindSafe;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Mutex;
use lightning::events::ClosureReason;
use lightning::ln::functional_test_utils::{
connect_block, create_announced_chan_between_nodes, create_chanmon_cfgs, create_dummy_block,
create_network, create_node_cfgs, create_node_chanmgrs, send_payment, TestChanMonCfg,
};
use lightning::util::persist::{
KVStore, KVStoreSync, MonitorUpdatingPersister, KVSTORE_NAMESPACE_KEY_MAX_LEN,
};
use lightning::util::test_utils;
use lightning::{check_added_monitors, check_closed_broadcast, check_closed_event, io};
use rand::distr::Alphanumeric;
use rand::{rng, Rng};
/// Convenience alias for a [`MonitorUpdatingPersister`] whose collaborators are all
/// borrowed test doubles from [`test_utils`].
///
/// The two `TestKeysInterface` parameters fill the entropy-source and signer-provider
/// slots; `create_persister` below passes the same `keys_manager` for both.
type TestMonitorUpdatePersister<'a, K> = MonitorUpdatingPersister<
&'a K,
&'a test_utils::TestLogger,
&'a test_utils::TestKeysInterface,
&'a test_utils::TestKeysInterface,
&'a test_utils::TestBroadcaster,
&'a test_utils::TestFeeEstimator,
>;
// How much a single full payment (send + claim) is expected to advance each channel
// monitor's latest update id; `do_test_store` multiplies by this after every payment.
const EXPECTED_UPDATES_PER_PAYMENT: u64 = 5;
/// A simple in-memory key-value store for tests, backed by a nested `HashMap`:
/// the outer map is keyed by `"primary_namespace/secondary_namespace"`, the inner
/// map by the entry key.
pub struct InMemoryStore {
	// All state lives behind a single Mutex; each operation locks for its duration.
	persisted_bytes: Mutex<HashMap<String, HashMap<String, Vec<u8>>>>,
}

impl InMemoryStore {
	/// Creates an empty store.
	pub fn new() -> Self {
		Self { persisted_bytes: Mutex::new(HashMap::new()) }
	}

	/// Returns a copy of the bytes stored under the given namespace pair and key.
	///
	/// Errors with `io::ErrorKind::NotFound` when either the namespace pair or the
	/// key is absent (with distinct messages for each case).
	fn read_internal(
		&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
	) -> io::Result<Vec<u8>> {
		let persisted_lock = self.persisted_bytes.lock().unwrap();
		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
		if let Some(namespace_map) = persisted_lock.get(&prefixed) {
			namespace_map
				.get(key)
				.cloned()
				.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Key not found"))
		} else {
			Err(io::Error::new(io::ErrorKind::NotFound, "Namespace not found"))
		}
	}

	/// Inserts (or overwrites) `key` -> `buf` under the namespace pair, creating the
	/// inner map on first use. Infallible beyond lock poisoning.
	fn write_internal(
		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
	) -> io::Result<()> {
		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
		persisted_lock.entry(prefixed).or_default().insert(key.to_string(), buf);
		Ok(())
	}

	/// Removes `key` from the namespace pair. Missing namespaces and missing keys are
	/// silently ignored; `_lazy` is accepted for API parity but has no effect here.
	fn remove_internal(
		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool,
	) -> io::Result<()> {
		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
		if let Some(namespace_map) = persisted_lock.get_mut(&prefixed) {
			namespace_map.remove(key);
		}
		Ok(())
	}

	/// Lists all keys currently stored under the namespace pair (unspecified order);
	/// an unknown namespace pair yields an empty list, not an error.
	fn list_internal(
		&self, primary_namespace: &str, secondary_namespace: &str,
	) -> io::Result<Vec<String>> {
		// Read-only: an immutable lock plus `get` suffices (no `entry` needed).
		let persisted_lock = self.persisted_bytes.lock().unwrap();
		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
		Ok(persisted_lock
			.get(&prefixed)
			.map(|namespace_map| namespace_map.keys().cloned().collect())
			.unwrap_or_default())
	}
}
/// Async [`KVStore`] implementation.
///
/// All work is performed synchronously via the `*_internal` helpers before the future
/// is constructed, so every returned future is already resolved when first polled.
/// Computing the result eagerly also lets the `'static` future avoid borrowing the
/// `&str` arguments.
impl KVStore for InMemoryStore {
	fn read(
		&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
	) -> Pin<Box<dyn Future<Output = Result<Vec<u8>, io::Error>> + 'static + Send>> {
		let res = self.read_internal(primary_namespace, secondary_namespace, key);
		Box::pin(async move { res })
	}
	fn write(
		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
	) -> Pin<Box<dyn Future<Output = Result<(), io::Error>> + 'static + Send>> {
		let res = self.write_internal(primary_namespace, secondary_namespace, key, buf);
		Box::pin(async move { res })
	}
	fn remove(
		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool,
	) -> Pin<Box<dyn Future<Output = Result<(), io::Error>> + 'static + Send>> {
		let res = self.remove_internal(primary_namespace, secondary_namespace, key, lazy);
		Box::pin(async move { res })
	}
	fn list(
		&self, primary_namespace: &str, secondary_namespace: &str,
	) -> Pin<Box<dyn Future<Output = Result<Vec<String>, io::Error>> + 'static + Send>> {
		let res = self.list_internal(primary_namespace, secondary_namespace);
		Box::pin(async move { res })
	}
}
impl KVStoreSync for InMemoryStore {
fn read(
&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
) -> io::Result<Vec<u8>> {
self.read_internal(primary_namespace, secondary_namespace, key)
}
fn write(
&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
) -> io::Result<()> {
self.write_internal(primary_namespace, secondary_namespace, key, buf)
}
fn remove(
&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool,
) -> io::Result<()> {
self.remove_internal(primary_namespace, secondary_namespace, key, lazy)
}
fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result<Vec<String>> {
self.list_internal(primary_namespace, secondary_namespace)
}
}
// No manual `unsafe impl Send`/`Sync` is needed here: `InMemoryStore`'s only field is
// a `Mutex<HashMap<String, HashMap<String, Vec<u8>>>>`, and `Mutex<T>` is `Send` and
// `Sync` whenever `T: Send`, so both auto traits are derived automatically. The
// previous `unsafe impl`s were redundant (and carried no SAFETY justification).
/// Returns a path under the system temp directory suffixed with a random 7-character
/// alphanumeric directory name. Only the path is constructed; the directory itself is
/// not created.
pub(crate) fn random_storage_path() -> PathBuf {
	let rand_dir: String = rng().sample_iter(Alphanumeric).take(7).map(char::from).collect();
	std::env::temp_dir().join(rand_dir)
}
/// Exercises the full read/write/remove/list surface of a [`KVStoreSync`] backend,
/// including the invalid-namespace/key panic paths and maximum-length namespaces/keys.
pub(crate) fn do_read_write_remove_list_persist<K: KVStoreSync + RefUnwindSafe>(kv_store: &K) {
	let payload = vec![42u8; 32];
	let primary_namespace = "testspace";
	let secondary_namespace = "testsubspace";
	let key = "testkey";

	// Writing with both namespaces set, and with both empty, must succeed.
	kv_store.write(primary_namespace, secondary_namespace, key, payload.clone()).unwrap();
	kv_store.write("", "", key, payload.clone()).unwrap();

	// An empty primary namespace combined with a non-empty secondary one must panic.
	let outcome =
		std::panic::catch_unwind(|| kv_store.write("", secondary_namespace, key, payload.clone()));
	assert!(outcome.is_err());

	// An empty key must panic as well.
	let outcome = std::panic::catch_unwind(|| {
		kv_store.write(primary_namespace, secondary_namespace, "", payload.clone())
	});
	assert!(outcome.is_err());

	// Round-trip: listing finds exactly our key and reading returns the written bytes.
	let keys = kv_store.list(primary_namespace, secondary_namespace).unwrap();
	assert_eq!(keys, vec![key.to_string()]);
	let read_back = kv_store.read(primary_namespace, secondary_namespace, key).unwrap();
	assert_eq!(payload, &*read_back);

	// After removal the namespace lists as empty.
	kv_store.remove(primary_namespace, secondary_namespace, key, false).unwrap();
	assert!(kv_store.list(primary_namespace, secondary_namespace).unwrap().is_empty());

	// Namespaces and keys at the maximum allowed length must work end-to-end too.
	let max_chars = "A".repeat(KVSTORE_NAMESPACE_KEY_MAX_LEN);
	kv_store.write(&max_chars, &max_chars, &max_chars, payload.clone()).unwrap();

	let keys = kv_store.list(&max_chars, &max_chars).unwrap();
	assert_eq!(keys, vec![max_chars.clone()]);
	let read_back = kv_store.read(&max_chars, &max_chars, &max_chars).unwrap();
	assert_eq!(payload, &*read_back);

	kv_store.remove(&max_chars, &max_chars, &max_chars, false).unwrap();
	assert!(kv_store.list(&max_chars, &max_chars).unwrap().is_empty());
}
/// Builds a [`MonitorUpdatingPersister`] over `store` wired to the test doubles in
/// `chanmon_cfg`. The keys manager is passed twice, serving as both the entropy
/// source and the signer provider.
pub(crate) fn create_persister<'a, K: KVStoreSync + Sync>(
	store: &'a K, chanmon_cfg: &'a TestChanMonCfg, max_pending_updates: u64,
) -> TestMonitorUpdatePersister<'a, K> {
	let cfg = chanmon_cfg;
	MonitorUpdatingPersister::new(
		store,
		&cfg.logger,
		max_pending_updates,
		&cfg.keys_manager,
		&cfg.keys_manager,
		&cfg.tx_broadcaster,
		&cfg.fee_estimator,
	)
}
/// Builds a [`test_utils::TestChainMonitor`] that persists through `persister` and
/// borrows its chain source, broadcaster, logger, fee estimator and keys manager
/// from `chanmon_cfg`.
pub(crate) fn create_chain_monitor<'a, K: KVStoreSync + Sync>(
	chanmon_cfg: &'a TestChanMonCfg, persister: &'a TestMonitorUpdatePersister<'a, K>,
) -> test_utils::TestChainMonitor<'a> {
	let cfg = chanmon_cfg;
	test_utils::TestChainMonitor::new(
		Some(&cfg.chain_source),
		&cfg.tx_broadcaster,
		&cfg.logger,
		&cfg.fee_estimator,
		persister,
		&cfg.keys_manager,
	)
}
/// End-to-end test over two nodes whose chain monitors persist through
/// [`MonitorUpdatingPersister`]s backed by `store_0`/`store_1`: opens a channel,
/// sends payments in both directions, then force-closes and confirms the commitment
/// transaction, checking the persisted monitor state after every step.
pub(crate) fn do_test_store<K: KVStoreSync + Sync>(store_0: &K, store_1: &K) {
	// Deliberately different pending-update limits per node — NOTE(review): presumably
	// so update consolidation triggers at different cadences on each side; confirm
	// against the `MonitorUpdatingPersister` docs.
	let persister_0_max_pending_updates = 7;
	let persister_1_max_pending_updates = 3;
	let chanmon_cfgs = create_chanmon_cfgs(2);
	let persister_0 = create_persister(store_0, &chanmon_cfgs[0], persister_0_max_pending_updates);
	let persister_1 = create_persister(store_1, &chanmon_cfgs[1], persister_1_max_pending_updates);
	let chain_mon_0 = create_chain_monitor(&chanmon_cfgs[0], &persister_0);
	let chain_mon_1 = create_chain_monitor(&chanmon_cfgs[1], &persister_1);
	let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
	node_cfgs[0].chain_monitor = chain_mon_0;
	node_cfgs[1].chain_monitor = chain_mon_1;
	let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
	let nodes = create_network(2, &node_cfgs, &node_chanmgrs);

	// No channels exist yet, so neither persister should have any monitors stored.
	let mut persisted_chan_data_0 = persister_0.read_all_channel_monitors_with_updates().unwrap();
	assert_eq!(persisted_chan_data_0.len(), 0);
	let mut persisted_chan_data_1 = persister_1.read_all_channel_monitors_with_updates().unwrap();
	assert_eq!(persisted_chan_data_1.len(), 0);

	// Re-reads all persisted monitors on both nodes and asserts that exactly one
	// channel monitor exists with the expected latest update id.
	macro_rules! check_persisted_data {
		($expected_update_id:expr) => {
			persisted_chan_data_0 = persister_0.read_all_channel_monitors_with_updates().unwrap();
			assert_eq!(persisted_chan_data_0.len(), 1);
			for (_, mon) in persisted_chan_data_0.iter() {
				assert_eq!(mon.get_latest_update_id(), $expected_update_id);
			}
			persisted_chan_data_1 = persister_1.read_all_channel_monitors_with_updates().unwrap();
			assert_eq!(persisted_chan_data_1.len(), 1);
			for (_, mon) in persisted_chan_data_1.iter() {
				assert_eq!(mon.get_latest_update_id(), $expected_update_id);
			}
		};
	}

	// Opening the channel creates the initial monitor (update id 0) on both sides.
	let _ = create_announced_chan_between_nodes(&nodes, 0, 1);
	check_persisted_data!(0);

	// Each full payment advances both monitors by EXPECTED_UPDATES_PER_PAYMENT.
	let expected_route = &[&nodes[1]][..];
	send_payment(&nodes[0], expected_route, 8_000_000);
	check_persisted_data!(EXPECTED_UPDATES_PER_PAYMENT);
	let expected_route = &[&nodes[0]][..];
	send_payment(&nodes[1], expected_route, 4_000_000);
	check_persisted_data!(2 * EXPECTED_UPDATES_PER_PAYMENT);

	// Keep alternating payment direction until 2 * persister_0_max_pending_updates
	// payments have been sent in total (payments 3..=14), checking the persisted
	// update id after each one.
	let mut sender = 0;
	for i in 3..=persister_0_max_pending_updates * 2 {
		let receiver;
		if sender == 0 {
			sender = 1;
			receiver = 0;
		} else {
			sender = 0;
			receiver = 1;
		}
		let expected_route = &[&nodes[receiver]][..];
		send_payment(&nodes[sender], expected_route, 21_000);
		check_persisted_data!(i * EXPECTED_UPDATES_PER_PAYMENT);
	}

	// Force-close from node 0, broadcasting its latest commitment transaction, and
	// check the resulting close event, broadcast, and monitor update.
	let message = "Channel force-closed".to_owned();
	nodes[0]
		.node
		.force_close_broadcasting_latest_txn(
			&nodes[0].node.list_channels()[0].channel_id,
			&nodes[1].node.get_our_node_id(),
			message.clone(),
		)
		.unwrap();
	check_closed_event!(
		nodes[0],
		1,
		ClosureReason::HolderForceClosed { broadcasted_latest_txn: Some(true), message },
		[nodes[1].node.get_our_node_id()],
		100000
	);
	check_closed_broadcast!(nodes[0], true);
	check_added_monitors!(nodes[0], 1);

	// Mine node 0's commitment transaction on node 1's chain (the tx appears twice in
	// the dummy block) so node 1 also observes the channel as closed.
	let node_txn = nodes[0].tx_broadcaster.txn_broadcast();
	assert_eq!(node_txn.len(), 1);
	let txn = vec![node_txn[0].clone(), node_txn[0].clone()];
	let dummy_block = create_dummy_block(nodes[0].best_block_hash(), 42, txn);
	connect_block(&nodes[1], &dummy_block);
	check_closed_broadcast!(nodes[1], true);
	let reason = ClosureReason::CommitmentTxConfirmed;
	let node_id_0 = nodes[0].node.get_our_node_id();
	check_closed_event!(nodes[1], 1, reason, false, [node_id_0], 100000);
	check_added_monitors!(nodes[1], 1);

	// The close adds one final monitor update on top of all the payment updates.
	check_persisted_data!(persister_0_max_pending_updates * 2 * EXPECTED_UPDATES_PER_PAYMENT + 1);
}