use crate::{
book::protocol::command::{Command, CommandKind},
types::IdempotencyKey,
};
use async_trait::async_trait;
use dashmap::{DashMap, mapref::entry::Entry};
use futures::future::{BoxFuture, FutureExt};
use std::{
fmt,
hash::{Hash, Hasher},
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use tracing::{trace, warn};
/// Default lifetime of an idempotency record, in seconds (three hours).
pub const DEFAULT_IDEMPOTENCY_TTL_SECS: u64 = 60 * 60 * 3;
/// Default pause between sweeps of the in-memory store's pruner, in seconds (one hour).
pub const DEFAULT_DASHMAP_PRUNE_INTERVAL_SECS: u64 = 3_600;
/// Interval, in seconds, substituted when a zero prune interval is requested.
pub const MIN_DASHMAP_PRUNE_INTERVAL_SECS: u64 = 1;

/// Tunables for the idempotency helper.
#[derive(Clone, Debug)]
pub struct IdempotencyConfig {
    /// How long reservations and completed responses are kept; it is added to the current
    /// time to form the absolute `expires_at_ms` handed to the store.
    pub ttl: Duration,
}

impl Default for IdempotencyConfig {
    /// Uses a TTL of [`DEFAULT_IDEMPOTENCY_TTL_SECS`] seconds.
    fn default() -> Self {
        let ttl = Duration::from_secs(DEFAULT_IDEMPOTENCY_TTL_SECS);
        Self { ttl }
    }
}
/// Successful result of an idempotent execution.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum IdempotencyOutcome<R> {
    /// The submission ran, either for this call or for the shared execution this call joined.
    Executed(R),
    /// A response previously stored for the same key and fingerprint was returned without
    /// submitting again.
    Replayed(R),
}

/// Failure of an idempotent execution.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum IdempotencyError<StoreErr, SubmitErr> {
    /// The key is already associated with a request carrying a different fingerprint.
    Conflict,
    /// The store holds an unfinished reservation with the same fingerprint that this helper
    /// instance is not driving itself.
    InFlightElsewhere,
    /// The idempotency store reported an error.
    Store(StoreErr),
    /// The submit function reported an error.
    Submit(SubmitErr),
    /// NOTE(review): not constructed by any code visible in this module; presumably meant for
    /// a joined execution that disappeared before producing a result — confirm.
    JoinDropped,
}

/// Answer of [`IdempotencyStore::reserve`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ReserveOutcome<R> {
    /// The key was free and is now reserved by the caller, who should perform the submission.
    Reserved,
    /// A completed entry with a matching fingerprint exists; carries its stored response.
    Replayed(R),
    /// An entry exists for the key under a different fingerprint.
    Conflict,
    /// An unfinished reservation with the same fingerprint exists.
    InFlight,
}
/// Backend that remembers idempotency reservations and completed responses of type `R`.
///
/// `IdempotencyHelper` calls `reserve` first. When it gets [`ReserveOutcome::Reserved`], it runs
/// the submission and then calls either `complete` (submission succeeded) or
/// `release_inflight` (submission failed).
#[async_trait]
pub trait IdempotencyStore<R>: Send + Sync + 'static
where
    R: Clone + Send + Sync + 'static,
{
    /// Backend error. It must be `Clone` because results, errors included, are handed to every
    /// waiter of a shared (`futures::future::Shared`) execution.
    type Error: Clone + Send + Sync + 'static;
    /// Tries to claim `key` for the request identified by `fingerprint`.
    ///
    /// `expires_at_ms` is an absolute Unix timestamp in milliseconds (the helper passes the
    /// current time plus its configured TTL). See [`ReserveOutcome`] for the possible answers.
    /// The in-memory implementation answers `Replayed`/`InFlight` only for a matching
    /// fingerprint and `Conflict` for a mismatching one.
    async fn reserve(
        &self,
        key: &IdempotencyKey,
        fingerprint: u64,
        expires_at_ms: i64,
    ) -> Result<ReserveOutcome<R>, Self::Error>;
    /// Stores `response` as the final result for `key`, expiring at `expires_at_ms`
    /// (absolute Unix milliseconds). The helper calls this only after a successful submission
    /// that followed a `Reserved` answer.
    async fn complete(
        &self,
        key: &IdempotencyKey,
        fingerprint: u64,
        response: R,
        expires_at_ms: i64,
    ) -> Result<(), Self::Error>;
    /// Drops the in-flight reservation for `key` if it was made with `fingerprint`, so the
    /// request can be retried. The helper treats this call as best effort and does not return
    /// its error to callers.
    async fn release_inflight(
        &self,
        key: &IdempotencyKey,
        fingerprint: u64,
    ) -> Result<(), Self::Error>;
}
/// In-memory [`IdempotencyStore`] backed by a [`DashMap`].
///
/// Clones share the same map. On construction, a weak reference to the map is handed to
/// `spawn_pruner`, which sweeps expired completed entries in the background.
#[derive(Clone, Debug)]
pub struct DashMapIdempotencyStore<R>
where
    R: Clone + Send + Sync + 'static,
{
    /// Reservation or completed response per key.
    entries: Arc<DashMap<IdempotencyKey, DashMapStoreEntry<R>>>,
    /// Pause between pruner sweeps; never zero after normalization.
    prune_interval: Duration,
}

impl<R> Default for DashMapIdempotencyStore<R>
where
    R: Clone + Send + Sync + 'static,
{
    /// Equivalent to [`DashMapIdempotencyStore::new`].
    fn default() -> Self {
        Self::with_prune_interval(Duration::from_secs(DEFAULT_DASHMAP_PRUNE_INTERVAL_SECS))
    }
}

impl<R> DashMapIdempotencyStore<R>
where
    R: Clone + Send + Sync + 'static,
{
    /// Creates a store whose pruner sweeps every [`DEFAULT_DASHMAP_PRUNE_INTERVAL_SECS`] seconds.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a store whose pruner sweeps every `prune_interval`.
    ///
    /// A zero interval is replaced by [`MIN_DASHMAP_PRUNE_INTERVAL_SECS`] seconds.
    pub fn with_prune_interval(prune_interval: Duration) -> Self {
        let interval = normalize_prune_interval(prune_interval);
        let shared_entries = Arc::new(DashMap::new());
        spawn_pruner(Arc::downgrade(&shared_entries), interval);
        Self {
            entries: shared_entries,
            prune_interval: interval,
        }
    }

    /// Returns the effective (normalized) prune interval.
    pub fn prune_interval(&self) -> Duration {
        self.prune_interval
    }
}
/// Value kept per key by [`DashMapIdempotencyStore`].
#[derive(Clone, Debug)]
enum DashMapStoreEntry<R> {
    /// A reservation whose submission has not finished yet.
    InFlight {
        /// Fingerprint of the request that made the reservation.
        fingerprint: u64,
        /// Absolute Unix-ms expiry. An expired reservation is only logged by `reserve`, never
        /// evicted.
        expires_at_ms: i64,
    },
    /// A finished submission and its response.
    Completed {
        /// Fingerprint of the request that produced `response`.
        fingerprint: u64,
        /// Response replayed to matching retries.
        response: R,
        /// Absolute Unix-ms expiry, after which `reserve` or the pruner evicts the entry.
        expires_at_ms: i64,
    },
}
fn spawn_pruner<R>(
entries: std::sync::Weak<DashMap<IdempotencyKey, DashMapStoreEntry<R>>>,
prune_interval: Duration,
) where
R: Clone + Send + Sync + 'static,
{
tokio::spawn(async move {
trace!(
prune_interval_ms = prune_interval.as_millis(),
"dashmap idempotency pruner started"
);
loop {
tokio::time::sleep(prune_interval).await;
let Some(entries) = entries.upgrade() else {
trace!("dashmap idempotency pruner stopping (store dropped)");
break;
};
let now_ms = now_unix_ms();
entries.retain(|_, entry| match entry {
DashMapStoreEntry::InFlight { .. } => true,
DashMapStoreEntry::Completed { expires_at_ms, .. } => *expires_at_ms > now_ms,
});
}
});
}
/// Returns `prune_interval` unchanged unless it is zero.
///
/// A zero interval is replaced by [`MIN_DASHMAP_PRUNE_INTERVAL_SECS`] seconds and a warning is
/// logged (presumably so the pruner never sweeps back-to-back with no pause). Non-zero values
/// below that minimum, including sub-second ones, are passed through as-is.
fn normalize_prune_interval(prune_interval: Duration) -> Duration {
    if !prune_interval.is_zero() {
        return prune_interval;
    }
    warn!(
        min_prune_interval_secs = MIN_DASHMAP_PRUNE_INTERVAL_SECS,
        "dashmap idempotency prune interval of 0 is not allowed; clamping to minimum"
    );
    Duration::from_secs(MIN_DASHMAP_PRUNE_INTERVAL_SECS)
}
/// Derives a prune interval from a TTL: one third of `ttl`, with zero replaced by the minimum.
fn recommended_prune_interval(ttl: Duration) -> Duration {
    // Dividing by a non-zero constant cannot fail, so no fallback is needed.
    let one_third = ttl / 3;
    normalize_prune_interval(one_third)
}
#[async_trait]
impl<R> IdempotencyStore<R> for DashMapIdempotencyStore<R>
where
    R: Clone + Send + Sync + 'static,
{
    /// The in-memory store never fails.
    type Error = std::convert::Infallible;

    /// Claims `key` unless a live entry already exists:
    ///
    /// - no entry: inserts an `InFlight` reservation and answers `Reserved`;
    /// - unexpired `Completed`: `Replayed` (with a clone of the response) for a matching
    ///   fingerprint, `Conflict` otherwise;
    /// - expired `Completed`: the stale entry is evicted and the lookup is retried;
    /// - `InFlight`: `InFlight` for a matching fingerprint, `Conflict` otherwise. An expired
    ///   reservation is only logged, never evicted, so it persists until it is completed or
    ///   released.
    async fn reserve(
        &self,
        key: &IdempotencyKey,
        fingerprint: u64,
        expires_at_ms: i64,
    ) -> Result<ReserveOutcome<R>, Self::Error> {
        loop {
            let now_ms = now_unix_ms();
            let occupied = match self.entries.entry(key.clone()) {
                Entry::Occupied(occupied) => occupied,
                Entry::Vacant(vacant) => {
                    vacant.insert(DashMapStoreEntry::InFlight {
                        fingerprint,
                        expires_at_ms,
                    });
                    return Ok(ReserveOutcome::Reserved);
                }
            };
            let decided = match occupied.get() {
                DashMapStoreEntry::Completed {
                    expires_at_ms: stored_expiry,
                    ..
                } if *stored_expiry <= now_ms => None,
                DashMapStoreEntry::Completed {
                    fingerprint: stored,
                    response,
                    ..
                } => Some(if *stored == fingerprint {
                    ReserveOutcome::Replayed(response.clone())
                } else {
                    ReserveOutcome::Conflict
                }),
                DashMapStoreEntry::InFlight {
                    fingerprint: stored,
                    expires_at_ms: stored_expiry,
                } => {
                    if *stored_expiry <= now_ms {
                        warn!(
                            idempotency_key = %key,
                            inflight_expires_at_ms = *stored_expiry,
                            now_ms,
                            "encountered expired in-flight idempotency entry; keeping lock until completion"
                        );
                    }
                    Some(if *stored == fingerprint {
                        ReserveOutcome::InFlight
                    } else {
                        ReserveOutcome::Conflict
                    })
                }
            };
            if let Some(outcome) = decided {
                return Ok(outcome);
            }
            // Only an expired completion gets here: evict it (releasing the shard lock) and
            // look again, which normally finds the slot vacant.
            occupied.remove();
        }
    }

    /// Records the completed response for `key`.
    async fn complete(
        &self,
        key: &IdempotencyKey,
        fingerprint: u64,
        response: R,
        expires_at_ms: i64,
    ) -> Result<(), Self::Error> {
        let record = DashMapStoreEntry::Completed {
            fingerprint,
            response,
            expires_at_ms,
        };
        // Unconditional overwrite: whatever was stored under `key` (normally the caller's own
        // `InFlight` reservation) is replaced.
        self.entries.insert(key.clone(), record);
        Ok(())
    }

    /// Removes the reservation for `key` only if it is still `InFlight` with this exact
    /// fingerprint. Completed entries and other requests' reservations are left untouched.
    async fn release_inflight(
        &self,
        key: &IdempotencyKey,
        fingerprint: u64,
    ) -> Result<(), Self::Error> {
        // `remove_if` checks and removes under one shard lock, like the entry API would.
        self.entries.remove_if(key, |_, entry| {
            matches!(
                entry,
                DashMapStoreEntry::InFlight { fingerprint: held, .. } if *held == fingerprint
            )
        });
        Ok(())
    }
}
/// Strategy that reduces a request payload `C` to the `u64` fingerprint used for idempotency.
///
/// `IdempotencyHelper` treats two requests with the same key but different fingerprints as a
/// conflict.
pub trait FingerprintPolicy<C>: Send + Sync + Clone + 'static {
    /// Returns the fingerprint of `payload`.
    fn fingerprint(&self, payload: &C) -> u64;
}
/// [`FingerprintPolicy`] for [`Command`] payloads; delegates to [`fingerprint_command`].
#[derive(Debug, Clone, Copy, Default)]
pub struct BetexCommandFingerprint;
impl FingerprintPolicy<Command> for BetexCommandFingerprint {
    fn fingerprint(&self, payload: &Command) -> u64 {
        fingerprint_command(payload)
    }
}
/// Computes the idempotency fingerprint of a [`Command`].
///
/// `CreateMarket`, `PlaceOrder` and `PlaceBinaryOrder` hash an explicit list of fields behind a
/// per-variant tag. Fields elided with `..` do not influence the fingerprint. Every other
/// command kind hashes the market id together with the whole `kind`.
///
/// Fields are hashed by reference. `&T` hashes exactly like `T`, so this produces the same
/// fingerprints as hashing owned copies, without cloning `account_id` on every call.
///
/// NOTE(review): `hash64` uses `DefaultHasher`, whose algorithm the standard library does not
/// promise to keep across Rust releases. If fingerprints are persisted by an external store, a
/// toolchain upgrade could turn legitimate retries into conflicts — confirm.
pub fn fingerprint_command(cmd: &Command) -> u64 {
    // Distinct tags keep the per-variant tuples from colliding with each other.
    const FP_CREATE_MARKET: u8 = 1;
    const FP_PLACE_ORDER: u8 = 2;
    const FP_PLACE_BINARY_ORDER: u8 = 3;
    match &cmd.kind {
        CommandKind::CreateMarket {
            name,
            market_model,
            book_type,
            market_kind,
            market_state,
            market_phase,
            runner_ids,
            runner_labels,
        } => hash64(&(
            FP_CREATE_MARKET,
            cmd.market_id,
            name,
            market_model,
            book_type,
            market_kind,
            market_state,
            market_phase,
            runner_ids,
            runner_labels,
        )),
        CommandKind::PlaceOrder {
            runner_id,
            account_id,
            side,
            odds,
            stake,
            persistence,
            time_in_force,
            ..
        } => hash64(&(
            FP_PLACE_ORDER,
            cmd.market_id,
            account_id,
            *runner_id,
            *side,
            *odds,
            *stake,
            *persistence,
            *time_in_force,
        )),
        CommandKind::PlaceBinaryOrder {
            account_id,
            side,
            price_ticks,
            qty_shares,
            time_in_force,
            ..
        } => hash64(&(
            FP_PLACE_BINARY_ORDER,
            cmd.market_id,
            account_id,
            *side,
            *price_ticks,
            *qty_shares,
            *time_in_force,
        )),
        _ => hash64(&(cmd.market_id, &cmd.kind)),
    }
}
/// Hashes `value` to a `u64` with the standard library's `DefaultHasher`.
///
/// Every call starts from `DefaultHasher::new()`, so equal values give equal results within a
/// build. NOTE(review): std does not guarantee the algorithm stays the same across Rust
/// releases.
fn hash64<T: Hash + ?Sized>(value: &T) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
/// Cloneable handle to a single keyed execution; every clone resolves to a clone of its output.
type SharedResult<R, StoreErr, SubmitErr> = futures::future::Shared<
    BoxFuture<'static, Result<IdempotencyOutcome<R>, IdempotencyError<StoreErr, SubmitErr>>>,
>;

/// Process-local table of executions an `IdempotencyHelper` is currently driving.
type InflightMap<R, StoreErr, SubmitErr> =
    Arc<DashMap<IdempotencyKey, LocalInflightEntry<R, StoreErr, SubmitErr>>>;

/// One row of the local in-flight table.
#[derive(Clone)]
struct LocalInflightEntry<R, StoreErr, SubmitErr> {
    /// Fingerprint of the request that started the execution; joiners must match it.
    fingerprint: u64,
    /// The execution itself, awaited by the starter and every joiner.
    shared: SharedResult<R, StoreErr, SubmitErr>,
}

impl<R, StoreErr, SubmitErr> fmt::Debug for LocalInflightEntry<R, StoreErr, SubmitErr> {
    /// Shows only the fingerprint; the shared future is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("LocalInflightEntry");
        out.field("fingerprint", &self.fingerprint);
        out.finish_non_exhaustive()
    }
}
/// Runs submissions at most once per [`IdempotencyKey`], replaying stored responses to retries.
///
/// Two layers are combined:
/// - `store` persists reservations and completed responses;
/// - `inflight` is a process-local table of shared futures, so concurrent callers of this
///   helper with the same key await one submission instead of racing each other.
///
/// Cloning the helper shares `store` and `inflight` (both behind `Arc`); `policy` and `config`
/// are cloned.
#[derive(Clone, Debug)]
pub struct IdempotencyHelper<S, R, E, FP>
where
    S: IdempotencyStore<R>,
    R: Clone + Send + Sync + 'static,
    E: Clone + Send + Sync + 'static,
    FP: Clone + Send + Sync + 'static,
{
    /// Persistent reservation/response backend.
    store: Arc<S>,
    /// Fingerprint policy used by `execute`.
    policy: FP,
    /// TTL settings applied to every reservation.
    config: IdempotencyConfig,
    /// Executions currently driven by this helper (and its clones).
    inflight: InflightMap<R, S::Error, E>,
}
impl<S, R, E, FP> IdempotencyHelper<S, R, E, FP>
where
S: IdempotencyStore<R>,
R: Clone + Send + Sync + 'static,
E: Clone + Send + Sync + 'static,
FP: Clone + Send + Sync + 'static,
{
pub fn new(store: S, policy: FP, config: IdempotencyConfig) -> Self {
Self {
store: Arc::new(store),
policy,
config,
inflight: Arc::new(DashMap::new()),
}
}
pub fn with_store(store: Arc<S>, policy: FP, config: IdempotencyConfig) -> Self {
Self {
store,
policy,
config,
inflight: Arc::new(DashMap::new()),
}
}
pub fn store(&self) -> Arc<S> {
Arc::clone(&self.store)
}
pub async fn execute<C, SubmitFn, SubmitFut>(
&self,
key: IdempotencyKey,
payload: C,
submit_fn: SubmitFn,
) -> Result<IdempotencyOutcome<R>, IdempotencyError<S::Error, E>>
where
FP: FingerprintPolicy<C>,
C: Send + 'static,
SubmitFn: FnOnce(C) -> SubmitFut + Send + 'static,
SubmitFut: std::future::Future<Output = Result<R, E>> + Send + 'static,
{
let fingerprint = self.policy.fingerprint(&payload);
self.execute_with_fingerprint(key, fingerprint, payload, submit_fn)
.await
}
pub async fn execute_with_fingerprint<C, SubmitFn, SubmitFut>(
&self,
key: IdempotencyKey,
fingerprint: u64,
payload: C,
submit_fn: SubmitFn,
) -> Result<IdempotencyOutcome<R>, IdempotencyError<S::Error, E>>
where
C: Send + 'static,
SubmitFn: FnOnce(C) -> SubmitFut + Send + 'static,
SubmitFut: std::future::Future<Output = Result<R, E>> + Send + 'static,
{
let now_ms = now_unix_ms();
let ttl_ms = i64::try_from(self.config.ttl.as_millis()).unwrap_or(i64::MAX);
let expires_at_ms = now_ms.saturating_add(ttl_ms);
let key_for_task = key.clone();
let inflight = Arc::clone(&self.inflight);
let store = Arc::clone(&self.store);
let shared: SharedResult<R, S::Error, E> = async move {
match store
.reserve(&key_for_task, fingerprint, expires_at_ms)
.await
.map_err(IdempotencyError::Store)?
{
ReserveOutcome::Replayed(response) => {
inflight.remove(&key_for_task);
Ok(IdempotencyOutcome::Replayed(response))
}
ReserveOutcome::Conflict => {
inflight.remove(&key_for_task);
Err(IdempotencyError::Conflict)
}
ReserveOutcome::InFlight => {
inflight.remove(&key_for_task);
Err(IdempotencyError::InFlightElsewhere)
}
ReserveOutcome::Reserved => {
let result = submit_fn(payload).await;
match result {
Ok(response) => {
let complete_result = store
.complete(
&key_for_task,
fingerprint,
response.clone(),
expires_at_ms,
)
.await;
inflight.remove(&key_for_task);
if let Err(err) = complete_result {
return Err(IdempotencyError::Store(err));
}
Ok(IdempotencyOutcome::Executed(response))
}
Err(err) => {
store
.release_inflight(&key_for_task, fingerprint)
.await
.ok();
inflight.remove(&key_for_task);
Err(IdempotencyError::Submit(err))
}
}
}
}
}
.boxed()
.shared();
let waiter = match self.inflight.entry(key) {
Entry::Occupied(existing) => {
let entry = existing.get();
if entry.fingerprint != fingerprint {
return Err(IdempotencyError::Conflict);
}
entry.shared.clone()
}
Entry::Vacant(v) => {
v.insert(LocalInflightEntry {
fingerprint,
shared: shared.clone(),
});
shared
}
};
waiter.await
}
}
impl<R, E, FP> IdempotencyHelper<DashMapIdempotencyStore<R>, R, E, FP>
where
    R: Clone + Send + Sync + 'static,
    E: Clone + Send + Sync + 'static,
    FP: Clone + Send + Sync + 'static,
{
    /// Builds a helper backed by a fresh in-memory [`DashMapIdempotencyStore`].
    ///
    /// The store's prune interval is one third of `config.ttl`, with zero replaced by the
    /// minimum interval.
    pub fn with_dashmap_store(policy: FP, config: IdempotencyConfig) -> Self {
        let store =
            DashMapIdempotencyStore::with_prune_interval(recommended_prune_interval(config.ttl));
        Self::new(store, policy, config)
    }
}
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`; a value too large for `i64` saturates to
/// `i64::MAX`.
fn now_unix_ms() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX)
}