use crate::orderbook::error::OrderBookError;
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use pricelevel::{Hash32, Id, PriceLevelSnapshot};
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use tracing::warn;
/// Selects where the reference price for the price-band check should come from.
///
/// This module only stores the selection (in `RiskConfig::reference_price`); the resolved
/// reference value is passed in by the caller of the admission check.
/// NOTE(review): the mapping from each variant to an actual price lives outside this file —
/// confirm the semantics below against that resolver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum ReferencePriceSource {
    /// Presumably the price of the most recent trade — confirm against the resolver.
    LastTrade,
    /// Presumably the midpoint of the best bid and best ask — confirm against the resolver.
    Mid,
    /// A fixed, caller-chosen reference price (assumed to use the same units as order prices).
    FixedPrice(u128),
}
/// Per-account pre-trade risk limits.
///
/// Every limit is optional; a `None` field disables the corresponding check. The default value
/// has every check disabled.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RiskConfig {
    /// Maximum number of tracked open orders per account. A new limit order is rejected once
    /// the account's current count has reached this value (`current >= limit`).
    pub max_open_orders_per_account: Option<u64>,
    /// Upper bound on an account's resting notional (`quantity * price`, saturating), counting
    /// the incoming order. Rejected when `current + attempted > limit`.
    pub max_notional_per_account: Option<u128>,
    /// Maximum allowed deviation of a limit price from the reference price, in basis points
    /// (1 bps = 1/10_000).
    pub price_band_bps: Option<u32>,
    /// Where the band's reference price comes from. Inside this module it is only consulted to
    /// decide whether to log a warning when no reference value is available.
    pub reference_price: Option<ReferencePriceSource>,
}
impl RiskConfig {
    /// Returns a configuration with no limits enabled (identical to `RiskConfig::default()`).
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Caps how many open orders each account may have tracked at once.
    #[inline]
    #[must_use]
    pub fn with_max_open_orders_per_account(self, n: u64) -> Self {
        Self {
            max_open_orders_per_account: Some(n),
            ..self
        }
    }

    /// Caps each account's total resting notional (quantity times price).
    #[inline]
    #[must_use]
    pub fn with_max_notional_per_account(self, n: u128) -> Self {
        Self {
            max_notional_per_account: Some(n),
            ..self
        }
    }

    /// Enables the price-band check. Limit prices may deviate from the reference price by at
    /// most `bps` basis points, and `source` records where that reference comes from.
    #[inline]
    #[must_use]
    pub fn with_price_band_bps(self, bps: u32, source: ReferencePriceSource) -> Self {
        Self {
            price_band_bps: Some(bps),
            reference_price: Some(source),
            ..self
        }
    }
}
/// Running per-account totals maintained by `RiskState`.
///
/// Both fields are atomics, so they can be updated through a shared reference.
#[derive(Debug, Default)]
pub struct RiskCounters {
    /// Number of tracked orders for the account that have not been fully filled or cancelled.
    /// Decrements clamp at zero.
    pub(super) open_count: AtomicU64,
    /// Running resting notional: increased by `remaining_qty * price` when an order is tracked
    /// and reduced as orders fill or are cancelled. Decrements clamp at zero.
    pub(super) resting_notional: AtomicCell<u128>,
}
/// Per-order bookkeeping, kept so that later fills and cancels can be attributed back to the
/// owning account's counters.
#[derive(Debug, Clone, Copy)]
pub(super) struct RiskEntry {
    /// Account the order belongs to.
    pub(super) account: Hash32,
    /// Price recorded when the order was admitted or rebuilt from a snapshot.
    pub(super) price: u128,
    /// Quantity still resting; reduced as fills arrive.
    pub(super) remaining_qty: u64,
}
/// Risk-check state: the active configuration, per-account counters and per-order entries.
///
/// The maps are `DashMap`s and the counters are atomics, so the admission checks and the
/// bookkeeping hooks take `&self`.
#[derive(Debug, Default)]
pub struct RiskState {
    /// Active limits. `None` means checks always pass and the bookkeeping hooks are no-ops.
    pub(super) config: Option<RiskConfig>,
    /// Open-order count and resting notional, keyed by account.
    pub(super) counters: DashMap<Hash32, RiskCounters>,
    /// Tracked resting orders, keyed by order id.
    pub(super) orders: DashMap<Id, RiskEntry>,
    /// One-shot latch that limits the "no reference price" warning to at most one log line
    /// per installed config (`set_config` re-arms it).
    pub(super) warned_no_reference: AtomicBool,
}
/// Atomically lowers `counter` by `delta`, stopping at zero instead of wrapping around.
///
/// Uses a relaxed compare-and-swap retry loop, so concurrent updates are never lost.
#[inline]
fn saturating_sub_u64(counter: &AtomicU64, delta: u64) {
    let mut observed = counter.load(Ordering::Relaxed);
    loop {
        let target = observed.saturating_sub(delta);
        match counter.compare_exchange_weak(observed, target, Ordering::Relaxed, Ordering::Relaxed)
        {
            Ok(_) => break,
            // Lost a race (or a spurious failure): retry against the freshest value.
            Err(latest) => observed = latest,
        }
    }
}
/// Atomically lowers `cell` by `delta`, stopping at zero instead of wrapping around.
///
/// Retries the compare-and-swap with whatever value another writer left behind until the
/// update lands.
#[inline]
fn saturating_sub_u128(cell: &AtomicCell<u128>, delta: u128) {
    let mut seen = cell.load();
    while let Err(latest) = cell.compare_exchange(seen, seen.saturating_sub(delta)) {
        seen = latest;
    }
}
impl RiskState {
    /// Creates a disabled risk state with no tracked orders and no counters.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Installs `cfg` as the active configuration and re-arms the one-shot
    /// "no reference price" warning.
    ///
    /// Existing counters and tracked orders are left untouched.
    pub fn set_config(&mut self, cfg: RiskConfig) {
        self.config = Some(cfg);
        self.warned_no_reference.store(false, Ordering::Relaxed);
    }

    /// Returns the active configuration, or `None` when risk checks are disabled.
    #[inline]
    #[must_use]
    pub fn config(&self) -> Option<&RiskConfig> {
        self.config.as_ref()
    }

    /// Disables risk checks.
    ///
    /// Counters and tracked orders are retained. While disabled, however, the `on_*` hooks are
    /// no-ops, so that retained state is not kept up to date.
    /// NOTE(review): after re-enabling, the retained state may be stale; presumably callers
    /// should follow up with `rebuild_from_snapshot` — confirm.
    pub fn disable(&mut self) {
        self.config = None;
    }

    /// Pre-trade admission check for a limit order of `quantity` at `price` from `account`.
    ///
    /// Always succeeds when no config is installed. `reference_price` is the already-resolved
    /// reference for the price-band check. When it is `None` or zero, the band check is skipped
    /// (with a one-time warning if a reference source is configured).
    ///
    /// # Errors
    ///
    /// * [`OrderBookError::RiskMaxOpenOrders`] if the account already has
    ///   `max_open_orders_per_account` tracked orders.
    /// * [`OrderBookError::RiskMaxNotional`] if its resting notional plus `quantity * price`
    ///   (saturating) would exceed `max_notional_per_account`.
    /// * [`OrderBookError::RiskPriceBand`] if `price` deviates from the reference by more than
    ///   `price_band_bps` basis points.
    #[inline]
    pub(super) fn check_limit_admission(
        &self,
        account: Hash32,
        price: u128,
        quantity: u64,
        reference_price: Option<u128>,
    ) -> Result<(), OrderBookError> {
        let Some(cfg) = self.config.as_ref() else {
            return Ok(());
        };
        self.check_account_limits(cfg, account, price, quantity)?;
        self.check_price_band(cfg, price, reference_price)
    }

    /// Admission check for market orders. No market-order limits exist yet, so this always
    /// returns `Ok(())`.
    #[inline]
    pub(super) fn check_market_admission(&self, _account: Hash32) -> Result<(), OrderBookError> {
        Ok(())
    }

    /// Starts tracking a resting order. The order is counted as open, and
    /// `remaining_qty * price` is added to the account's resting notional.
    ///
    /// No-op when risk checks are disabled. Re-admitting an id that is still tracked replaces
    /// the old entry and releases its contribution, so the order is not double counted.
    pub(super) fn on_admission(
        &self,
        order_id: Id,
        account: Hash32,
        price: u128,
        remaining_qty: u64,
    ) {
        if self.config.is_none() {
            return;
        }
        self.track(
            order_id,
            RiskEntry {
                account,
                price,
                remaining_qty,
            },
        );
    }

    /// Applies a fill of `filled_qty` at `maker_price` to the tracked maker order `maker_id`.
    ///
    /// The released quantity is clamped to what the order still has resting. An overshooting
    /// fill therefore cannot eat into notional contributed by the account's other orders.
    /// When the order is exhausted it stops being tracked and the open count drops by one.
    /// No-op when risk checks are disabled or the order is unknown.
    ///
    /// NOTE(review): notional is released at `maker_price`, while admission adds it at the
    /// recorded entry price. This assumes the two match for resting orders — confirm, otherwise
    /// `resting_notional` drifts.
    pub(super) fn on_fill(&self, maker_id: Id, filled_qty: u64, maker_price: u128) {
        if self.config.is_none() {
            return;
        }
        // Update the entry inside its own scope so the shard guard is released before `remove`
        // touches the same map again.
        let (account, applied_qty, fully_filled) = {
            let Some(mut entry) = self.orders.get_mut(&maker_id) else {
                return;
            };
            let applied_qty = filled_qty.min(entry.remaining_qty);
            entry.remaining_qty -= applied_qty;
            (entry.account, applied_qty, entry.remaining_qty == 0)
        };
        // Only the caller that actually removes the entry closes the order. A concurrent
        // `on_cancel` that removed it first has already decremented `open_count`.
        let closes_order = fully_filled && self.orders.remove(&maker_id).is_some();
        self.release(account, Self::notional(applied_qty, maker_price), closes_order);
    }

    /// Stops tracking `order_id`. Its remaining notional is released, and the account's open
    /// count drops by one. No-op when risk checks are disabled or the order is unknown
    /// (e.g. already fully filled).
    pub(super) fn on_cancel(&self, order_id: Id) {
        if self.config.is_none() {
            return;
        }
        let Some((_, entry)) = self.orders.remove(&order_id) else {
            return;
        };
        self.release(
            entry.account,
            Self::notional(entry.remaining_qty, entry.price),
            true,
        );
    }

    /// Discards all tracked orders and counters, then rebuilds them from the resting orders in
    /// `bids` and `asks`. This runs whether or not a config is installed.
    ///
    /// Each order's remaining quantity is its visible plus hidden quantity (saturating), valued
    /// at its level's price.
    pub(super) fn rebuild_from_snapshot(
        &self,
        bids: &[PriceLevelSnapshot],
        asks: &[PriceLevelSnapshot],
    ) {
        self.orders.clear();
        self.counters.clear();
        for level in bids.iter().chain(asks.iter()) {
            let price = level.price();
            for order in level.orders() {
                let remaining_qty = order
                    .visible_quantity()
                    .saturating_add(order.hidden_quantity());
                self.track(
                    order.id(),
                    RiskEntry {
                        account: order.user_id(),
                        price,
                        remaining_qty,
                    },
                );
            }
        }
    }

    /// Enforces the per-account open-order and notional limits of `cfg`.
    fn check_account_limits(
        &self,
        cfg: &RiskConfig,
        account: Hash32,
        price: u128,
        quantity: u64,
    ) -> Result<(), OrderBookError> {
        if cfg.max_open_orders_per_account.is_none() && cfg.max_notional_per_account.is_none() {
            return Ok(());
        }
        // One map lookup serves both checks; an untracked account counts as zero usage.
        let (open_count, resting_notional) = self
            .counters
            .get(&account)
            .map(|c| (c.open_count.load(Ordering::Relaxed), c.resting_notional.load()))
            .unwrap_or((0, 0));
        if let Some(limit) = cfg.max_open_orders_per_account {
            if open_count >= limit {
                return Err(OrderBookError::RiskMaxOpenOrders {
                    account,
                    current: open_count,
                    limit,
                });
            }
        }
        if let Some(limit) = cfg.max_notional_per_account {
            let attempted = Self::notional(quantity, price);
            if resting_notional.saturating_add(attempted) > limit {
                return Err(OrderBookError::RiskMaxNotional {
                    account,
                    current: resting_notional,
                    attempted,
                    limit,
                });
            }
        }
        Ok(())
    }

    /// Enforces the price band of `cfg` against `reference_price`, if both are present.
    fn check_price_band(
        &self,
        cfg: &RiskConfig,
        price: u128,
        reference_price: Option<u128>,
    ) -> Result<(), OrderBookError> {
        let Some(bps_limit) = cfg.price_band_bps else {
            return Ok(());
        };
        let Some(reference) = reference_price else {
            // A band is configured but no reference could be resolved yet, so skip the check.
            // The warning is logged once per installed config so the gap is visible without
            // flooding the log.
            if cfg.reference_price.is_some()
                && self
                    .warned_no_reference
                    .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                    .is_ok()
            {
                warn!(
                    "risk: price-band check configured but no reference price available; \
                     check skipped until a trade or two-sided book establishes a reference"
                );
            }
            return Ok(());
        };
        // A zero reference would divide by zero; it is treated as no usable reference.
        if reference == 0 {
            return Ok(());
        }
        let deviation_bps = Self::band_deviation_bps(price, reference);
        if deviation_bps > bps_limit {
            return Err(OrderBookError::RiskPriceBand {
                submitted: price,
                reference,
                deviation_bps,
                limit_bps: bps_limit,
            });
        }
        Ok(())
    }

    /// Deviation of `price` from a non-zero `reference`, in basis points. The result is
    /// rounded up and clamped to `u32::MAX`.
    ///
    /// Rounding up makes `deviation > limit` hold exactly when the true fractional deviation
    /// exceeds an integer limit. Truncation would let e.g. 100.99 bps through a 100 bps band.
    /// NOTE(review): `diff * 10_000` saturates for differences above ~3.4e34, which would
    /// understate the deviation at such magnitudes.
    fn band_deviation_bps(price: u128, reference: u128) -> u32 {
        let scaled = price.abs_diff(reference).saturating_mul(10_000);
        // `reference > 0`, so no division by zero. The `+ 1` cannot overflow: a zero remainder
        // adds nothing, and otherwise `reference >= 2` keeps the quotient <= u128::MAX / 2.
        let bps = scaled / reference + u128::from(scaled % reference != 0);
        u32::try_from(bps).unwrap_or(u32::MAX)
    }

    /// Records `entry` under `order_id` and adds its contribution to the owning account's
    /// counters. If `order_id` was already tracked, the stale entry's contribution is released
    /// first so the id is not counted twice.
    fn track(&self, order_id: Id, entry: RiskEntry) {
        if let Some(stale) = self.orders.insert(order_id, entry) {
            self.release(
                stale.account,
                Self::notional(stale.remaining_qty, stale.price),
                true,
            );
        }
        let counters = self.counters.entry(entry.account).or_default();
        counters.open_count.fetch_add(1, Ordering::Relaxed);
        Self::saturating_add_notional(
            &counters.resting_notional,
            Self::notional(entry.remaining_qty, entry.price),
        );
    }

    /// Subtracts `notional` from `account`'s resting notional. When `closes_order` is set, it
    /// also subtracts one from the account's open count. Both counters clamp at zero.
    fn release(&self, account: Hash32, notional: u128, closes_order: bool) {
        if let Some(counters) = self.counters.get(&account) {
            saturating_sub_u128(&counters.resting_notional, notional);
            if closes_order {
                saturating_sub_u64(&counters.open_count, 1);
            }
        }
    }

    /// Notional value of `quantity` at `price`, saturating at `u128::MAX`.
    #[inline]
    fn notional(quantity: u64, price: u128) -> u128 {
        u128::from(quantity).saturating_mul(price)
    }

    /// Atomically adds `delta` to `cell`, saturating at `u128::MAX`.
    ///
    /// `AtomicCell::fetch_add` wraps on overflow, which could turn a saturated total into a
    /// tiny one.
    fn saturating_add_notional(cell: &AtomicCell<u128>, delta: u128) {
        let mut current = cell.load();
        while let Err(actual) = cell.compare_exchange(current, current.saturating_add(delta)) {
            current = actual;
        }
    }
}
/// Unit tests for the risk configuration builder and the `RiskState` admission checks and
/// bookkeeping hooks. They inspect the `pub(super)` maps directly.
#[cfg(test)]
mod tests {
    use super::*;
    use pricelevel::Id;

    /// Builds a deterministic account id whose 32 bytes are all `byte`.
    fn account(byte: u8) -> Hash32 {
        Hash32::new([byte; 32])
    }

    // Every builder method sets its field(s).
    #[test]
    fn test_risk_config_builder() {
        let cfg = RiskConfig::new()
            .with_max_open_orders_per_account(5)
            .with_max_notional_per_account(1_000_000)
            .with_price_band_bps(500, ReferencePriceSource::LastTrade);
        assert_eq!(cfg.max_open_orders_per_account, Some(5));
        assert_eq!(cfg.max_notional_per_account, Some(1_000_000));
        assert_eq!(cfg.price_band_bps, Some(500));
        assert_eq!(cfg.reference_price, Some(ReferencePriceSource::LastTrade));
    }

    // Without a config, checks pass and the hooks record nothing.
    #[test]
    fn test_risk_state_no_config_is_passthrough() {
        let state = RiskState::new();
        let acct = account(1);
        let order_id = Id::new_uuid();
        assert!(
            state
                .check_limit_admission(acct, 100, 10, Some(100))
                .is_ok()
        );
        assert!(state.check_market_admission(acct).is_ok());
        state.on_admission(order_id, acct, 100, 10);
        state.on_fill(order_id, 5, 100);
        state.on_cancel(order_id);
        assert!(state.counters.is_empty());
        assert!(state.orders.is_empty());
    }

    // Admission increments both counters; cancel returns them to zero and drops the entry.
    #[test]
    fn test_on_admission_then_on_cancel_round_trip() {
        let mut state = RiskState::new();
        state.set_config(
            RiskConfig::new()
                .with_max_open_orders_per_account(10)
                .with_max_notional_per_account(1_000_000),
        );
        let acct = account(2);
        let order_id = Id::new_uuid();
        state.on_admission(order_id, acct, 100, 10);
        let counters = state
            .counters
            .get(&acct)
            .expect("counters entry created on admission");
        assert_eq!(counters.open_count.load(Ordering::Relaxed), 1);
        assert_eq!(counters.resting_notional.load(), 1_000);
        // Release the `Ref` guard before `on_cancel` accesses `state.counters` again.
        drop(counters);
        state.on_cancel(order_id);
        let counters = state
            .counters
            .get(&acct)
            .expect("counters entry retained after cancel");
        assert_eq!(counters.open_count.load(Ordering::Relaxed), 0);
        assert_eq!(counters.resting_notional.load(), 0);
        assert!(!state.orders.contains_key(&order_id));
    }

    // A partial fill reduces notional and remaining quantity but keeps the order open.
    #[test]
    fn test_on_fill_partial_keeps_open_count() {
        let mut state = RiskState::new();
        state.set_config(RiskConfig::new().with_max_notional_per_account(1_000_000));
        let acct = account(3);
        let order_id = Id::new_uuid();
        state.on_admission(order_id, acct, 100, 10);
        state.on_fill(order_id, 4, 100);
        let counters = state.counters.get(&acct).expect("counters entry present");
        assert_eq!(
            counters.open_count.load(Ordering::Relaxed),
            1,
            "partial fill must not drop open_count"
        );
        assert_eq!(
            counters.resting_notional.load(),
            6 * 100,
            "notional must be reduced by filled_qty * price"
        );
        let entry = state
            .orders
            .get(&order_id)
            .expect("entry retained after partial fill");
        assert_eq!(entry.remaining_qty, 6);
    }

    // A fill of the full remaining quantity closes the order and untracks it.
    #[test]
    fn test_on_fill_full_decrements_open_count() {
        let mut state = RiskState::new();
        state.set_config(RiskConfig::new().with_max_open_orders_per_account(10));
        let acct = account(4);
        let order_id = Id::new_uuid();
        state.on_admission(order_id, acct, 100, 10);
        state.on_fill(order_id, 10, 100);
        let counters = state.counters.get(&acct).expect("counters entry retained");
        assert_eq!(counters.open_count.load(Ordering::Relaxed), 0);
        assert_eq!(counters.resting_notional.load(), 0);
        assert!(!state.orders.contains_key(&order_id));
    }

    // Reaching the open-order limit yields `RiskMaxOpenOrders` with the current count.
    #[test]
    fn test_check_limit_admission_max_open_orders_breach_returns_typed_error() {
        let mut state = RiskState::new();
        state.set_config(RiskConfig::new().with_max_open_orders_per_account(2));
        let acct = account(5);
        state.on_admission(Id::new_uuid(), acct, 100, 1);
        state.on_admission(Id::new_uuid(), acct, 100, 1);
        let err = state
            .check_limit_admission(acct, 100, 1, Some(100))
            .expect_err("third admission must breach max_open_orders");
        match err {
            OrderBookError::RiskMaxOpenOrders {
                account: a,
                current,
                limit,
            } => {
                assert_eq!(a, acct);
                assert_eq!(current, 2);
                assert_eq!(limit, 2);
            }
            other => panic!("expected RiskMaxOpenOrders, got {other:?}"),
        }
    }

    // 800 resting + 300 attempted > 1_000 limit yields `RiskMaxNotional`.
    #[test]
    fn test_check_limit_admission_max_notional_breach_returns_typed_error() {
        let mut state = RiskState::new();
        state.set_config(RiskConfig::new().with_max_notional_per_account(1_000));
        let acct = account(6);
        state.on_admission(Id::new_uuid(), acct, 100, 8);
        let err = state
            .check_limit_admission(acct, 100, 3, Some(100))
            .expect_err("notional should be exceeded");
        match err {
            OrderBookError::RiskMaxNotional {
                account: a,
                current,
                attempted,
                limit,
            } => {
                assert_eq!(a, acct);
                assert_eq!(current, 800);
                assert_eq!(attempted, 300);
                assert_eq!(limit, 1_000);
            }
            other => panic!("expected RiskMaxNotional, got {other:?}"),
        }
    }

    // A price 10% above the reference (1_000 bps) breaches a 100 bps band.
    #[test]
    fn test_check_limit_admission_price_band_breach_returns_typed_error() {
        let mut state = RiskState::new();
        state.set_config(
            RiskConfig::new().with_price_band_bps(100, ReferencePriceSource::LastTrade),
        );
        let acct = account(7);
        let err = state
            .check_limit_admission(acct, 1_100_000, 1, Some(1_000_000))
            .expect_err("price band should be exceeded");
        match err {
            OrderBookError::RiskPriceBand {
                submitted,
                reference,
                deviation_bps,
                limit_bps,
            } => {
                assert_eq!(submitted, 1_100_000);
                assert_eq!(reference, 1_000_000);
                assert_eq!(deviation_bps, 1_000);
                assert_eq!(limit_bps, 100);
            }
            other => panic!("expected RiskPriceBand, got {other:?}"),
        }
    }

    // With no reference price, even an extreme price is admitted.
    #[test]
    fn test_check_limit_admission_no_reference_price_skips_band_check() {
        let mut state = RiskState::new();
        state.set_config(
            RiskConfig::new().with_price_band_bps(100, ReferencePriceSource::LastTrade),
        );
        assert!(
            state
                .check_limit_admission(account(8), 999_999_999, 1, None)
                .is_ok()
        );
    }

    // The first missing-reference check flips the warning latch, and it stays set afterwards.
    #[test]
    fn test_check_limit_admission_warns_only_once_when_no_reference_available() {
        let mut state = RiskState::new();
        state.set_config(
            RiskConfig::new().with_price_band_bps(100, ReferencePriceSource::LastTrade),
        );
        let acct = account(9);
        assert!(state.check_limit_admission(acct, 1, 1, None).is_ok());
        assert!(
            state.warned_no_reference.load(Ordering::Relaxed),
            "first call without reference should flip the latch"
        );
        assert!(state.check_limit_admission(acct, 2, 2, None).is_ok());
        assert!(state.warned_no_reference.load(Ordering::Relaxed));
    }

    // An order inside every configured limit is admitted.
    #[test]
    fn test_within_limits_admission_succeeds() {
        let mut state = RiskState::new();
        state.set_config(
            RiskConfig::new()
                .with_max_open_orders_per_account(10)
                .with_max_notional_per_account(1_000_000)
                .with_price_band_bps(500, ReferencePriceSource::LastTrade),
        );
        let acct = account(10);
        assert!(state.check_limit_admission(acct, 100, 5, Some(100)).is_ok());
    }

    // `disable` clears the config but keeps counters and entries; checks then pass.
    #[test]
    fn test_disable_keeps_counters() {
        let mut state = RiskState::new();
        state.set_config(RiskConfig::new().with_max_open_orders_per_account(10));
        let acct = account(11);
        let order_id = Id::new_uuid();
        state.on_admission(order_id, acct, 100, 10);
        state.disable();
        assert!(state.config().is_none());
        assert!(state.counters.contains_key(&acct));
        assert!(state.orders.contains_key(&order_id));
        assert!(
            state
                .check_limit_admission(acct, 100, 100, Some(100))
                .is_ok()
        );
    }

    // A fill larger than the resting quantity leaves both counters at zero rather than
    // wrapping around.
    #[test]
    fn test_on_fill_overshoot_clamps_counters_at_zero() {
        let mut state = RiskState::new();
        state.set_config(RiskConfig::new().with_max_notional_per_account(10_000));
        let acct = account(12);
        let order_id = Id::new_uuid();
        state.on_admission(order_id, acct, 100, 5);
        state.on_fill(order_id, 1_000_000, 100);
        let counters = state.counters.get(&acct).expect("counters present");
        assert_eq!(counters.open_count.load(Ordering::Relaxed), 0);
        assert_eq!(counters.resting_notional.load(), 0);
    }

    // Cancelling an order that was already fully filled must not decrement again.
    #[test]
    fn test_on_cancel_after_fully_filled_is_noop_and_does_not_wrap() {
        let mut state = RiskState::new();
        state.set_config(RiskConfig::new().with_max_open_orders_per_account(10));
        let acct = account(13);
        let order_id = Id::new_uuid();
        state.on_admission(order_id, acct, 100, 5);
        state.on_fill(order_id, 5, 100);
        state.on_cancel(order_id);
        let counters = state.counters.get(&acct).expect("counters present");
        assert_eq!(counters.open_count.load(Ordering::Relaxed), 0);
        assert_eq!(counters.resting_notional.load(), 0);
    }
}