use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use rustrade_core::Symbol;
use tokio::sync::{Mutex, MutexGuard};
/// One symbol's in-flight reservation.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PendingEntry {
    /// Total notional reserved for the symbol; repeated
    /// [`PendingMap::reserve`] calls for the same symbol add to it.
    pub notional: f64,
    /// When the entry was created or last topped up; TTL expiry is measured
    /// from this instant.
    reserved_at: Instant,
}

/// Per-symbol reservations that have not been released yet.
///
/// Entries are added by [`reserve`](Self::reserve), removed by
/// [`release`](Self::release), and purged by `expire_stale` once they are at
/// least a TTL old.
#[derive(Debug, Default)]
pub(crate) struct PendingMap {
    entries: HashMap<Symbol, PendingEntry>,
}

impl PendingMap {
    /// Removes every entry whose age at `now` is `>= ttl`.
    ///
    /// Age is computed with `saturating_duration_since`, so an entry stamped
    /// after `now` counts as age zero and is kept. This does not depend on
    /// whether `Instant::duration_since` saturates or panics in the running
    /// std version.
    fn expire_stale(&mut self, ttl: Duration, now: Instant) {
        self.entries
            .retain(|_, e| now.saturating_duration_since(e.reserved_at) < ttl);
    }

    /// Returns `true` if `symbol` currently has a reservation.
    pub(crate) fn contains(&self, symbol: &Symbol) -> bool {
        self.entries.contains_key(symbol)
    }

    /// Sum of the reserved notional across all symbols (0.0 when empty).
    pub(crate) fn gross_notional(&self) -> f64 {
        self.entries.values().map(|e| e.notional).sum()
    }

    /// Counts reserved symbols for which `is_open` returns `false`.
    ///
    /// The result saturates at `u32::MAX`. `is_open` is called once per
    /// reserved symbol, in unspecified (hash-map) order.
    /// NOTE(review): presumably `is_open` reports whether a position for the
    /// symbol already exists; confirm against callers.
    pub(crate) fn new_slots(&self, mut is_open: impl FnMut(&Symbol) -> bool) -> u32 {
        self.entries
            .keys()
            .filter(|s| !is_open(s))
            .count()
            .try_into()
            .unwrap_or(u32::MAX)
    }

    /// Adds `notional` to `symbol`'s reservation, creating the entry if it
    /// is absent.
    ///
    /// Topping up an existing entry also refreshes its timestamp, so the TTL
    /// restarts for the whole accumulated amount.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `notional` is NaN or infinite. Such a value
    /// would make [`gross_notional`](Self::gross_notional) NaN or infinite,
    /// and a NaN total compares `false` against any limit.
    pub(crate) fn reserve(&mut self, symbol: Symbol, notional: f64) {
        debug_assert!(
            notional.is_finite(),
            "pending notional must be finite, got {notional}"
        );
        let now = Instant::now();
        self.entries
            .entry(symbol)
            .and_modify(|e| {
                e.notional += notional;
                e.reserved_at = now;
            })
            .or_insert(PendingEntry {
                notional,
                reserved_at: now,
            });
    }

    /// Removes `symbol`'s reservation; a no-op if there is none.
    pub(crate) fn release(&mut self, symbol: &Symbol) {
        self.entries.remove(symbol);
    }
}
/// Cloneable handle to a [`PendingMap`] shared behind an async mutex.
///
/// Cloning the handle clones the `Arc`, so all clones see the same map.
/// Reservations older than `ttl` are purged each time the map is locked
/// through [`lock`](Self::lock).
#[derive(Debug, Clone)]
pub(crate) struct PendingEntryLedger {
    /// Shared map behind a `tokio::sync::Mutex`, whose guard may be held
    /// across `.await` points.
    inner: Arc<Mutex<PendingMap>>,
    /// Age at which `lock` discards a reservation.
    ttl: Duration,
}

impl PendingEntryLedger {
    /// Creates an empty ledger whose reservations expire after `ttl`.
    pub(crate) fn new(ttl: Duration) -> Self {
        let empty = PendingMap::default();
        Self {
            ttl,
            inner: Arc::new(Mutex::new(empty)),
        }
    }

    /// Acquires the map, drops stale reservations, and hands back the guard.
    ///
    /// Expiry uses the current time taken after the mutex is acquired. The
    /// caller can then query and reserve under this single guard.
    pub(crate) async fn lock(&self) -> MutexGuard<'_, PendingMap> {
        let mut map = self.inner.lock().await;
        let now = Instant::now();
        map.expire_stale(self.ttl, now);
        map
    }

    /// Removes `symbol`'s reservation, if any. TTL expiry is not run here.
    pub(crate) async fn release(&self, symbol: &Symbol) {
        let mut map = self.inner.lock().await;
        map.release(symbol);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`Symbol`] from a string literal.
    fn sym(s: &str) -> Symbol {
        Symbol::from(s)
    }

    #[tokio::test]
    async fn reserve_release_roundtrip() {
        let ledger = PendingEntryLedger::new(Duration::from_secs(30));

        let mut pending = ledger.lock().await;
        pending.reserve(sym("AAA"), 1_000.0);
        pending.reserve(sym("BBB"), 500.0);
        assert!(pending.contains(&sym("AAA")));
        assert_eq!(pending.gross_notional(), 1_500.0);
        // No symbol reported as open: both reservations count.
        assert_eq!(pending.new_slots(|_| false), 2);
        // AAA reported as open: only BBB counts.
        assert_eq!(pending.new_slots(|s| s == &sym("AAA")), 1);
        // `release` locks the same mutex, so this guard must be gone first.
        drop(pending);

        ledger.release(&sym("AAA")).await;
        let pending = ledger.lock().await;
        assert!(!pending.contains(&sym("AAA")));
        assert_eq!(pending.gross_notional(), 500.0);
    }

    #[tokio::test]
    async fn same_symbol_reservations_accumulate() {
        let ledger = PendingEntryLedger::new(Duration::from_secs(30));
        let mut pending = ledger.lock().await;
        for amount in [1_000.0, 250.0] {
            pending.reserve(sym("AAA"), amount);
        }
        assert_eq!(pending.gross_notional(), 1_250.0);
        assert_eq!(pending.new_slots(|_| false), 1, "one symbol, one slot");
    }

    #[tokio::test]
    async fn reservations_expire_after_ttl() {
        let ttl = Duration::from_millis(20);
        let ledger = PendingEntryLedger::new(ttl);
        {
            let mut pending = ledger.lock().await;
            pending.reserve(sym("AAA"), 1_000.0);
        }
        // Wait twice the TTL so the reservation is unambiguously stale.
        tokio::time::sleep(ttl * 2).await;

        let pending = ledger.lock().await;
        assert!(
            !pending.contains(&sym("AAA")),
            "stale reservation must expire on the next lock"
        );
        assert_eq!(pending.gross_notional(), 0.0);
    }
}