use std::{collections::BTreeSet, mem, pin::Pin};
use amaru_kernel::cbor;
use amaru_ouroboros_traits::{
CanValidateTransactions, Mempool, MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason,
TxSubmissionMempool,
};
use parking_lot::RwLock;
#[derive(Debug, Default)]
pub struct DummyMempool<T> {
inner: RwLock<DummyMempoolInner<T>>,
}
impl<T> DummyMempool<T> {
pub fn new() -> Self {
DummyMempool { inner: RwLock::new(DummyMempoolInner::new()) }
}
}
#[derive(Debug, Default)]
pub struct DummyMempoolInner<T> {
transactions: Vec<T>,
}
impl<Tx: cbor::Encode<()> + Send + Sync + 'static> CanValidateTransactions<Tx> for DummyMempool<Tx> {
fn validate_transaction(&self, _tx: Tx) -> Result<(), TransactionValidationError> {
Ok(())
}
}
impl<Tx: cbor::Encode<()> + Send + Sync + 'static> TxSubmissionMempool<Tx> for DummyMempool<Tx> {
fn insert(&self, tx: Tx, _tx_origin: TxOrigin) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
let tx_id = TxId::from(&tx);
let mut inner = self.inner.write();
inner.transactions.push(tx);
let seq_no = MempoolSeqNo(inner.transactions.len() as u64 - 1);
Ok((tx_id, seq_no))
}
fn get_tx(&self, _tx_id: &TxId) -> Option<Tx> {
None
}
fn tx_ids_since(&self, _from_seq: MempoolSeqNo, _limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
vec![]
}
fn wait_for_at_least(&self, _seq_no: MempoolSeqNo) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
Box::pin(async move { true })
}
fn get_txs_for_ids(&self, _ids: &[TxId]) -> Vec<Tx> {
vec![]
}
fn last_seq_no(&self) -> MempoolSeqNo {
MempoolSeqNo(self.inner.read().transactions.len() as u64)
}
}
impl<Tx: cbor::Encode<()> + Send + Sync + 'static> Mempool<Tx> for DummyMempool<Tx> {
fn take(&self) -> Vec<Tx> {
mem::take(&mut self.inner.write().transactions)
}
fn acknowledge<TxKey: Ord, I>(&self, tx: &Tx, keys: fn(&Tx) -> I)
where
I: IntoIterator<Item = TxKey>,
Self: Sized,
{
let keys_to_remove = BTreeSet::from_iter(keys(tx));
self.inner.write().transactions.retain(|tx| !keys(tx).into_iter().any(|k| keys_to_remove.contains(&k)));
}
}
impl<T> DummyMempoolInner<T> {
pub fn new() -> Self {
DummyMempoolInner { transactions: vec![] }
}
}
#[cfg(test)]
mod tests {
use amaru_kernel::cbor;
use super::*;
#[test]
fn take_empty() {
let mempool: DummyMempool<FakeTx<'_>> = DummyMempool::new();
assert_eq!(mempool.take(), vec![]);
}
#[test]
fn add_then_take() {
let mempool = DummyMempool::new();
mempool.add(FakeTx::new("tx1", &[1])).unwrap();
mempool.add(FakeTx::new("tx2", &[2])).unwrap();
assert_eq!(mempool.take(), vec![FakeTx::new("tx1", &[1]), FakeTx::new("tx2", &[2])]);
}
#[test]
fn invalidate_entries() {
let mempool = DummyMempool::new();
mempool.add(FakeTx::new("tx1", &[1, 2])).unwrap();
mempool.add(FakeTx::new("tx2", &[3, 4])).unwrap();
mempool.add(FakeTx::new("tx3", &[5, 6])).unwrap();
mempool.acknowledge(&FakeTx::new("tx4", &[2, 5, 7]), FakeTx::keys);
assert_eq!(mempool.take(), vec![FakeTx::new("tx2", &[3, 4])]);
}
#[derive(Debug, PartialEq, Eq)]
struct FakeTx<'a> {
id: &'a str,
inputs: Vec<usize>,
}
impl cbor::Encode<()> for FakeTx<'_> {
fn encode<W: cbor::encode::Write>(
&self,
e: &mut cbor::Encoder<W>,
_ctx: &mut (),
) -> Result<(), cbor::encode::Error<W::Error>> {
e.encode(self.id)?;
e.encode(&self.inputs)?;
Ok(())
}
}
impl<'a> FakeTx<'a> {
fn new(id: &'a str, inputs: &'_ [usize]) -> Self {
FakeTx { id, inputs: Vec::from(inputs) }
}
fn keys(&self) -> Vec<usize> {
self.inputs.clone()
}
}
}