use std::{
collections::{BTreeMap, BTreeSet},
mem,
pin::Pin,
sync::Arc,
};
use amaru_kernel::{cbor, to_cbor};
use amaru_ouroboros_traits::{
CanValidateTransactions, MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason,
TxSubmissionMempool, mempool::Mempool,
};
use tokio::sync::Notify;
/// An in-memory mempool whose state lives behind a shared, lock-protected `MempoolInner`.
///
/// Every field is either an `Arc` or the small `MempoolConfig`, so cloning a handle is cheap
/// and all clones operate on the same underlying entries and validator.
pub struct InMemoryMempool<Tx> {
    /// Limits applied when inserting transactions.
    config: MempoolConfig,
    /// Shared mutable state: stored entries, sequence counter and insertion notifier.
    inner: Arc<parking_lot::RwLock<MempoolInner<Tx>>>,
    /// Validator that `validate_transaction` delegates to.
    tx_validator: Arc<dyn CanValidateTransactions<Tx>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require `Tx: Clone` even
// though no `Tx` value is ever cloned here, only `Arc` handles and the configuration.
impl<Tx> Clone for InMemoryMempool<Tx> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            inner: Arc::clone(&self.inner),
            tx_validator: Arc::clone(&self.tx_validator),
        }
    }
}
/// The default mempool is built from the default `MempoolConfig` through `from_config`.
impl<Tx> Default for InMemoryMempool<Tx> {
    fn default() -> Self {
        let config = MempoolConfig::default();
        Self::from_config(config)
    }
}
impl<Tx> InMemoryMempool<Tx> {
    /// Creates an empty mempool that uses `config` for its limits and `tx_validator` for
    /// transaction validation.
    pub fn new(config: MempoolConfig, tx_validator: Arc<dyn CanValidateTransactions<Tx>>) -> Self {
        let inner = Arc::new(parking_lot::RwLock::new(MempoolInner::default()));
        Self { config, inner, tx_validator }
    }

    /// Creates an empty mempool from `config`, validating with `DefaultCanValidateTransactions`.
    pub fn from_config(config: MempoolConfig) -> Self {
        let tx_validator: Arc<dyn CanValidateTransactions<Tx>> = Arc::new(DefaultCanValidateTransactions);
        Self::new(config, tx_validator)
    }
}
/// Stateless transaction validator used when no custom validator is supplied.
#[derive(Clone, Debug, Default)]
pub struct DefaultCanValidateTransactions;
impl<Tx> CanValidateTransactions<Tx> for DefaultCanValidateTransactions {
    /// Accepts any transaction: the argument is ignored and `Ok(())` is always returned.
    fn validate_transaction(&self, _: Tx) -> Result<(), TransactionValidationError> {
        Ok(())
    }
}
/// Lock-protected state of the in-memory mempool.
#[derive(Debug)]
struct MempoolInner<Tx> {
// Sequence number that the next successfully inserted transaction will receive.
next_seq: u64,
// Stored entries, keyed by transaction id.
entries_by_id: BTreeMap<TxId, MempoolEntry<Tx>>,
// Index from sequence number to transaction id; `insert` adds to both maps together.
entries_by_seq: BTreeMap<MempoolSeqNo, TxId>,
// Notifier handle cloned by waiters and signalled after a successful insertion.
notify: Arc<Notify>,
}
impl<Tx> Default for MempoolInner<Tx> {
fn default() -> Self {
MempoolInner {
next_seq: 1,
entries_by_id: Default::default(),
entries_by_seq: Default::default(),
notify: Arc::new(Notify::new()),
}
}
}
impl<Tx: cbor::Encode<()> + Clone> MempoolInner<Tx> {
fn insert(
&mut self,
config: &MempoolConfig,
tx: Tx,
tx_origin: TxOrigin,
) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
if let Some(max_txs) = config.max_txs
&& self.entries_by_id.len() >= max_txs
{
return Err(TxRejectReason::MempoolFull);
}
let tx_id = TxId::from(&tx);
if self.entries_by_id.contains_key(&tx_id) {
return Err(TxRejectReason::Duplicate);
}
let seq_no = MempoolSeqNo(self.next_seq);
self.next_seq += 1;
let tx_size = to_cbor(&tx).len() as u32;
let entry = MempoolEntry { seq_no, tx_id, tx, tx_size, origin: tx_origin };
self.entries_by_id.insert(tx_id, entry);
self.entries_by_seq.insert(seq_no, tx_id);
Ok((tx_id, seq_no))
}
fn get_tx(&self, tx_id: &TxId) -> Option<Tx> {
self.entries_by_id.get(tx_id).map(|entry| entry.tx.clone())
}
fn tx_ids_since(&self, from_seq: MempoolSeqNo, limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
let mut result: Vec<(TxId, u32, MempoolSeqNo)> = self
.entries_by_seq
.range(from_seq..)
.take(limit as usize)
.map(|(seq, tx_id)| {
let Some(entry) = self.entries_by_id.get(tx_id) else {
panic!("Inconsistent mempool state: entry missing for tx_id {:?}", tx_id)
};
(*tx_id, entry.tx_size, *seq)
})
.collect();
result.sort_by_key(|(_, _, seq_no)| *seq_no);
result
}
fn get_txs_for_ids(&self, ids: &[TxId]) -> Vec<Tx> {
let mut result: Vec<(&TxId, &MempoolEntry<Tx>)> =
self.entries_by_id.iter().filter(|(key, _)| ids.contains(*key)).collect();
result.sort_by_key(|(_, entry)| entry.seq_no);
result.into_iter().map(|(_, entry)| entry.tx.clone()).collect()
}
}
/// A transaction held by the mempool together with its bookkeeping data.
///
/// The derived ordering compares fields in declaration order, so `seq_no` is compared first.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct MempoolEntry<Tx> {
// Sequence number assigned when the transaction was inserted.
seq_no: MempoolSeqNo,
// Identifier derived from the transaction at insertion time.
tx_id: TxId,
// The transaction itself.
tx: Tx,
// Length in bytes of the transaction's CBOR encoding, computed at insertion time.
tx_size: u32,
// Origin passed by the caller when the transaction was inserted.
origin: TxOrigin,
}
/// Tunable limits for the in-memory mempool.
#[derive(Debug, Clone, Default)]
pub struct MempoolConfig {
    /// Maximum number of stored transactions; `None` (the default) sets no limit.
    max_txs: Option<usize>,
}

impl MempoolConfig {
    /// Returns this configuration with the transaction limit set to `max`.
    pub fn with_max_txs(self, max: usize) -> Self {
        let mut config = self;
        config.max_txs = Some(max);
        config
    }
}
/// Lets the mempool itself act as a validator by delegating to its configured `tx_validator`.
impl<Tx: Send + Sync + 'static> CanValidateTransactions<Tx> for InMemoryMempool<Tx> {
/// Forwards `tx` to the validator supplied at construction time and returns its result.
fn validate_transaction(&self, tx: Tx) -> Result<(), TransactionValidationError> {
self.tx_validator.validate_transaction(tx)
}
}
impl<Tx: Send + Sync + 'static + cbor::Encode<()> + Clone> TxSubmissionMempool<Tx> for InMemoryMempool<Tx> {
    /// Inserts `tx` under the write lock and, on success, wakes every task currently waiting in
    /// `wait_for_at_least`.
    ///
    /// # Errors
    /// Propagates the rejection reason returned by the inner state (`MempoolFull` or `Duplicate`).
    fn insert(&self, tx: Tx, tx_origin: TxOrigin) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
        let mut inner = self.inner.write();
        let res = inner.insert(&self.config, tx, tx_origin);
        if res.is_ok() {
            // Signalled while the write lock is still held, i.e. after `next_seq` was bumped.
            inner.notify.notify_waiters();
        }
        res
    }

    /// Returns a clone of the transaction stored under `tx_id`, if any.
    fn get_tx(&self, tx_id: &TxId) -> Option<Tx> {
        self.inner.read().get_tx(tx_id)
    }

    /// Lists up to `limit` `(id, size, sequence number)` triples starting at `from_seq` (inclusive),
    /// in ascending sequence order.
    fn tx_ids_since(&self, from_seq: MempoolSeqNo, limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
        self.inner.read().tx_ids_since(from_seq, limit)
    }

    /// Resolves to `true` once the internal `next_seq` counter is greater than or equal to
    /// `seq_no`. The future never resolves to `false`; it stays pending until the condition holds.
    ///
    /// NOTE(review): `next_seq` is the number the *next* insertion will receive (`last_seq_no`
    /// reports `next_seq - 1`), so this resolves as soon as `seq_no - 1` has been assigned.
    /// Confirm against callers whether `next_seq > seq_no` was intended.
    fn wait_for_at_least(&self, seq_no: MempoolSeqNo) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        Box::pin(async move {
            loop {
                let notify = {
                    let inner = self.inner.read();
                    inner.notify.clone()
                };
                // Create the `Notified` future *before* sampling the counter. A `Notified` future
                // receives `notify_waiters()` wake-ups from the moment it is created, so an
                // insertion that lands after this line is never missed, and one that landed
                // before it is already visible in the counter read below.
                let notified = notify.notified();
                let current_next_seq = {
                    let inner = self.inner.read();
                    inner.next_seq
                };
                if current_next_seq >= seq_no.0 {
                    return true;
                }
                notified.await;
            }
        })
    }

    /// Returns clones of the stored transactions whose id appears in `ids`, ordered by sequence
    /// number.
    fn get_txs_for_ids(&self, ids: &[TxId]) -> Vec<Tx> {
        self.inner.read().get_txs_for_ids(ids)
    }

    /// Returns the sequence number given to the most recent successful insertion, or
    /// `MempoolSeqNo(0)` if nothing has been inserted yet (the counter starts at `1`).
    fn last_seq_no(&self) -> MempoolSeqNo {
        MempoolSeqNo(self.inner.read().next_seq - 1)
    }
}
impl<Tx: Send + Sync + 'static + cbor::Encode<()> + Clone> Mempool<Tx> for InMemoryMempool<Tx> {
fn take(&self) -> Vec<Tx> {
let mut inner = self.inner.write();
let entries = mem::take(&mut inner.entries_by_id);
let _ = mem::take(&mut inner.entries_by_seq);
entries.into_values().map(|entry| entry.tx).collect()
}
fn acknowledge<TxKey: Ord, I>(&self, tx: &Tx, keys: fn(&Tx) -> I)
where
I: IntoIterator<Item = TxKey>,
Self: Sized,
{
let keys_to_remove = BTreeSet::from_iter(keys(tx));
let mut inner = self.inner.write();
let seq_nos_to_remove: Vec<MempoolSeqNo> = inner
.entries_by_id
.values()
.filter(|entry| keys(&entry.tx).into_iter().any(|k| keys_to_remove.contains(&k)))
.map(|entry| entry.seq_no)
.collect();
inner.entries_by_id.retain(|_, entry| !keys(&entry.tx).into_iter().any(|k| keys_to_remove.contains(&k)));
for seq_no in seq_nos_to_remove {
inner.entries_by_seq.remove(&seq_no);
}
}
}
// Unit tests for the in-memory mempool, using a minimal string-backed transaction type.
#[cfg(test)]
mod tests {
use std::{ops::Deref, slice, str::FromStr, time::Duration};
use amaru_kernel::{Peer, cbor, cbor as minicbor};
use assertables::assert_some_eq_x;
use tokio::time::timeout;
use super::*;
// Inserts one transaction and checks that every read accessor observes it, that waiting for
// its sequence number resolves, and that waiting for a far larger one does not.
#[tokio::test]
async fn insert_a_transaction() -> anyhow::Result<()> {
let mempool = InMemoryMempool::from_config(MempoolConfig::default().with_max_txs(5));
let tx = Tx::from_str("tx1").unwrap();
let (tx_id, seq_nb) = mempool.insert(tx.clone(), TxOrigin::Remote(Peer::new("upstream"))).unwrap();
assert_some_eq_x!(mempool.get_tx(&tx_id), tx.clone());
assert_eq!(mempool.get_txs_for_ids(slice::from_ref(&tx_id)), vec![tx.clone()]);
// The expected size `5` is the stored `tx_size` of the encoded `Tx("tx1")`.
assert_eq!(mempool.tx_ids_since(seq_nb, 100), vec![(tx_id, 5, seq_nb)]);
assert!(mempool.wait_for_at_least(seq_nb).await, "should have at least seq no");
// No further insertion happens, so the wait must still be pending when the timeout fires.
assert!(
timeout(Duration::from_millis(100), mempool.wait_for_at_least(seq_nb.add(100))).await.is_err(),
"should timeout waiting for a seq no that is too high"
);
Ok(())
}
// Minimal transaction type for the tests: a CBOR-encodable wrapper around a `String`.
#[derive(Debug, PartialEq, Eq, Clone, cbor::Encode, cbor::Decode)]
struct Tx(#[n(0)] String);
// Exposes the wrapped string.
impl Deref for Tx {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
// Infallible construction from a string slice.
impl FromStr for Tx {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Tx(s.to_string()))
}
}
}