use pricelevel::{Hash32, Id};
use serde::{Deserialize, Serialize};
/// Self-trade prevention (STP) policy: what to do when an incoming (taker)
/// order meets resting orders at a price level whose `user_id()` equals the
/// taker's user id.
///
/// `#[repr(u8)]` with explicit discriminants pins every variant to a fixed
/// byte value (presumably relied on for compact/stable encoding — confirm
/// with the serialization callers before renumbering). `None` is the
/// `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum STPMode {
    /// STP disabled; `check_stp_at_level` always reports `STPAction::NoConflict`.
    #[default]
    None = 0,
    /// On a same-user conflict, `check_stp_at_level` reports
    /// `STPAction::CancelTaker` with the visible quantity resting ahead of the
    /// first same-user order.
    CancelTaker = 1,
    /// On a same-user conflict, `check_stp_at_level` reports
    /// `STPAction::CancelMaker` with the ids of every same-user order at the level.
    CancelMaker = 2,
    /// On a same-user conflict, `check_stp_at_level` reports
    /// `STPAction::CancelBoth` with the quantity ahead of, and the id of, the
    /// first same-user order.
    CancelBoth = 3,
}
impl std::fmt::Display for STPMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
STPMode::None => write!(f, "None"),
STPMode::CancelTaker => write!(f, "CancelTaker"),
STPMode::CancelMaker => write!(f, "CancelMaker"),
STPMode::CancelBoth => write!(f, "CancelBoth"),
}
}
}
impl STPMode {
    /// Reports whether self-trade prevention is active, i.e. whether the mode
    /// is anything other than [`STPMode::None`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        !matches!(self, STPMode::None)
    }
}
/// Outcome of scanning one price level for self-trades, as produced by
/// `check_stp_at_level`.
///
/// "Same user" means a resting order whose `user_id()` equals the taker's
/// user id. How the caller acts on each variant is not visible in this file;
/// presumably the engine fills up to `safe_quantity` and then performs the
/// named cancellation — confirm at the call site.
#[derive(Debug, Clone)]
pub(crate) enum STPAction {
    /// No same-user order at the level, STP is disabled, or the taker's user
    /// id is `Hash32::zero()`.
    NoConflict,
    /// Taker should be cancelled.
    CancelTaker {
        // Saturating sum of `visible_quantity()` over the orders that precede
        // the first same-user order in the level's slice.
        safe_quantity: u64,
    },
    /// Resting same-user orders should be cancelled.
    CancelMaker {
        // Ids of every same-user order at the level, in slice order; never
        // empty (an empty scan yields `NoConflict` instead).
        maker_order_ids: Vec<Id>,
    },
    /// Both the taker and one resting order should be cancelled.
    CancelBoth {
        // Same meaning as `CancelTaker::safe_quantity`.
        safe_quantity: u64,
        // Id of the *first* same-user order only; later same-user orders at
        // the level are not reported.
        maker_order_id: Id,
    },
}
/// Scans the resting `orders` of a single price level for orders owned by the
/// taker and translates any conflict into an [`STPAction`] according to `mode`.
///
/// An order conflicts when its `user_id()` equals `taker_user_id`. The slice
/// is inspected front to back (presumably in queue-priority order — confirm
/// with the caller).
///
/// * Returns [`STPAction::NoConflict`] straight away when `mode` is
///   [`STPMode::None`] or `taker_user_id` is `Hash32::zero()`; STP is bypassed
///   for such takers. It is also returned whenever no conflicting order exists.
/// * `CancelTaker` / `CancelBoth`: stop at the first conflicting order.
///   `safe_quantity` is the saturating sum of `visible_quantity()` over every
///   order before it; `CancelBoth` additionally carries that order's id.
/// * `CancelMaker`: the ids of *all* conflicting orders, in slice order.
#[inline]
pub(crate) fn check_stp_at_level(
    orders: &[std::sync::Arc<pricelevel::OrderType<()>>],
    taker_user_id: Hash32,
    mode: STPMode,
) -> STPAction {
    // Saturating total of the visible quantity resting strictly ahead of
    // `orders[end]`; saturation mirrors the overflow-safe accumulation the
    // matching side relies on.
    fn visible_ahead_of(orders: &[std::sync::Arc<pricelevel::OrderType<()>>], end: usize) -> u64 {
        orders[..end]
            .iter()
            .fold(0u64, |total, order| total.saturating_add(order.visible_quantity()))
    }

    if mode == STPMode::None || taker_user_id == Hash32::zero() {
        return STPAction::NoConflict;
    }

    let owned_by_taker =
        |order: &std::sync::Arc<pricelevel::OrderType<()>>| order.user_id() == taker_user_id;

    match mode {
        // Already handled by the guard above; kept for exhaustiveness.
        STPMode::None => STPAction::NoConflict,
        STPMode::CancelTaker => match orders.iter().position(owned_by_taker) {
            Some(first) => STPAction::CancelTaker {
                safe_quantity: visible_ahead_of(orders, first),
            },
            None => STPAction::NoConflict,
        },
        STPMode::CancelMaker => {
            let maker_order_ids: Vec<Id> = orders
                .iter()
                .filter_map(|order| owned_by_taker(order).then(|| order.id()))
                .collect();
            if maker_order_ids.is_empty() {
                STPAction::NoConflict
            } else {
                STPAction::CancelMaker { maker_order_ids }
            }
        }
        STPMode::CancelBoth => match orders.iter().position(owned_by_taker) {
            Some(first) => STPAction::CancelBoth {
                safe_quantity: visible_ahead_of(orders, first),
                maker_order_id: orders[first].id(),
            },
            None => STPAction::NoConflict,
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pricelevel::{Quantity, TimestampMs};
    use std::sync::Arc;

    type LevelOrder = Arc<pricelevel::OrderType<()>>;

    // Builds a fresh-id, GTC, price-100 standard sell order. The STP scan only
    // reads owner, visible size and id, so every other field is fixed here.
    fn sell_at_100(user_id: Hash32, quantity: Quantity, timestamp: TimestampMs) -> LevelOrder {
        Arc::new(pricelevel::OrderType::Standard {
            id: Id::new(),
            price: pricelevel::Price::new(100),
            quantity,
            side: pricelevel::Side::Sell,
            user_id,
            timestamp,
            time_in_force: pricelevel::TimeInForce::Gtc,
            extra_fields: (),
        })
    }

    #[test]
    fn test_stp_mode_default_is_none() {
        assert_eq!(STPMode::default(), STPMode::None);
    }

    #[test]
    fn test_stp_mode_is_enabled() {
        assert!(!STPMode::None.is_enabled());
        for mode in [STPMode::CancelTaker, STPMode::CancelMaker, STPMode::CancelBoth] {
            assert!(mode.is_enabled());
        }
    }

    #[test]
    fn test_stp_mode_display() {
        let cases = [
            (STPMode::None, "None"),
            (STPMode::CancelTaker, "CancelTaker"),
            (STPMode::CancelMaker, "CancelMaker"),
            (STPMode::CancelBoth, "CancelBoth"),
        ];
        for (mode, expected) in cases {
            assert_eq!(mode.to_string(), expected);
        }
    }

    #[test]
    fn test_check_stp_none_mode_returns_no_conflict() {
        let action = check_stp_at_level(&[], Hash32::zero(), STPMode::None);
        assert!(matches!(action, STPAction::NoConflict));
    }

    #[test]
    fn test_check_stp_zero_user_bypasses() {
        let user = Hash32::zero();
        let orders = [sell_at_100(user, Quantity::new(10), TimestampMs::new(0))];
        let action = check_stp_at_level(&orders, user, STPMode::CancelTaker);
        assert!(matches!(action, STPAction::NoConflict));
    }

    #[test]
    fn test_check_stp_cancel_taker_detects_same_user() {
        let user = Hash32::new([1u8; 32]);
        let orders = [sell_at_100(user, Quantity::new(10), TimestampMs::new(0))];
        let action = check_stp_at_level(&orders, user, STPMode::CancelTaker);
        if let STPAction::CancelTaker { safe_quantity } = action {
            assert_eq!(safe_quantity, 0);
        } else {
            panic!("expected CancelTaker action");
        }
    }

    #[test]
    fn test_check_stp_cancel_taker_safe_quantity_before_self() {
        let taker_user = Hash32::new([1u8; 32]);
        let other_user = Hash32::new([2u8; 32]);
        let orders = [
            sell_at_100(other_user, Quantity::new(5), TimestampMs::new(0)),
            sell_at_100(taker_user, Quantity::new(10), TimestampMs::new(1)),
        ];
        let action = check_stp_at_level(&orders, taker_user, STPMode::CancelTaker);
        if let STPAction::CancelTaker { safe_quantity } = action {
            assert_eq!(safe_quantity, 5);
        } else {
            panic!("expected CancelTaker action");
        }
    }

    #[test]
    fn test_check_stp_cancel_maker_collects_ids() {
        let taker_user = Hash32::new([1u8; 32]);
        let other_user = Hash32::new([2u8; 32]);
        let orders = [
            sell_at_100(taker_user, Quantity::new(5), TimestampMs::new(0)),
            sell_at_100(other_user, Quantity::new(3), TimestampMs::new(1)),
            sell_at_100(taker_user, Quantity::new(7), TimestampMs::new(2)),
        ];
        let action = check_stp_at_level(&orders, taker_user, STPMode::CancelMaker);
        if let STPAction::CancelMaker { maker_order_ids } = action {
            assert_eq!(maker_order_ids.len(), 2);
            assert_eq!(maker_order_ids[0], orders[0].id());
            assert_eq!(maker_order_ids[1], orders[2].id());
        } else {
            panic!("expected CancelMaker action");
        }
    }

    #[test]
    fn test_check_stp_cancel_both_detects_self() {
        let user = Hash32::new([1u8; 32]);
        let other_user = Hash32::new([2u8; 32]);
        let orders = [
            sell_at_100(other_user, Quantity::new(3), TimestampMs::new(0)),
            sell_at_100(user, Quantity::new(10), TimestampMs::new(1)),
        ];
        let action = check_stp_at_level(&orders, user, STPMode::CancelBoth);
        if let STPAction::CancelBoth {
            safe_quantity,
            maker_order_id,
        } = action
        {
            assert_eq!(safe_quantity, 3);
            assert_eq!(maker_order_id, orders[1].id());
        } else {
            panic!("expected CancelBoth action");
        }
    }

    #[test]
    fn test_check_stp_no_conflict_when_different_users() {
        let taker_user = Hash32::new([1u8; 32]);
        let other_user = Hash32::new([2u8; 32]);
        let orders = [sell_at_100(other_user, Quantity::new(10), TimestampMs::new(0))];
        for mode in [STPMode::CancelTaker, STPMode::CancelMaker, STPMode::CancelBoth] {
            assert!(matches!(
                check_stp_at_level(&orders, taker_user, mode),
                STPAction::NoConflict
            ));
        }
    }
}