use std::sync::Arc;
use std::thread;
use openpit::param::{AccountId, Asset, Fee, Pnl, Quantity, Side, TradeAmount};
use openpit::pretrade::policies::{
PnlBoundsAccountAssetBarrier, PnlBoundsBrokerBarrier, PnlBoundsKillSwitchPolicy,
PnlBoundsKillSwitchSettings,
};
use openpit::pretrade::{PostTradeContext, PreTradeContext, PreTradePolicy, RejectCode};
use openpit::storage::FullLocking;
use openpit::{Engine, FullSync, Instrument, OrderOperation, RequestFieldAccessError};
// Shared sizing for the concurrency tests below: every test spawns
// TOTAL_THREADS workers, each applying PER_THREAD_REPORTS reports of
// PNL_PER_REPORT, for a combined PnL of 8 * 5 * 10 = 400.
/// Number of worker threads spawned by each concurrency test.
const TOTAL_THREADS: usize = 8;
/// Number of execution reports each worker thread applies.
const PER_THREAD_REPORTS: usize = 5;
/// PnL carried by every execution report a worker applies.
const PNL_PER_REPORT: i64 = 10;

/// Kill-switch policy under test, parameterised with the `FullLocking` storage.
type TestPolicy = PnlBoundsKillSwitchPolicy<FullLocking>;
/// Minimal execution report used to drive the PnL kill-switch policy in these tests.
///
/// The engine reads it through the `openpit::Has*` accessor traits implemented
/// below; every accessor simply returns the stored field wrapped in `Ok`.
struct TestReport {
    instrument: Instrument,
    account_id: AccountId,
    pnl: Pnl,
    fee: Fee,
}

impl openpit::HasAccountId for TestReport {
    fn account_id(&self) -> Result<AccountId, RequestFieldAccessError> {
        let Self { account_id, .. } = self;
        Ok(*account_id)
    }
}

impl openpit::HasInstrument for TestReport {
    fn instrument(&self) -> Result<&Instrument, RequestFieldAccessError> {
        let Self { instrument, .. } = self;
        Ok(instrument)
    }
}

impl openpit::HasFee for TestReport {
    fn fee(&self) -> Result<Fee, RequestFieldAccessError> {
        let Self { fee, .. } = self;
        Ok(*fee)
    }
}

impl openpit::HasPnl for TestReport {
    fn pnl(&self) -> Result<Pnl, RequestFieldAccessError> {
        let Self { pnl, .. } = self;
        Ok(*pnl)
    }
}
/// Returns the `USD` asset used as both settlement asset and quote currency.
///
/// # Panics
/// Panics if `Asset::new` rejects the literal `"USD"`.
fn usd() -> Asset {
    let code = "USD";
    Asset::new(code).expect("asset code must be valid")
}
/// Builds an `AccountId` from a numeric identifier via `AccountId::from_u64`.
fn account(id: u64) -> AccountId {
    let account_id = AccountId::from_u64(id);
    account_id
}
/// Parses a PnL literal such as `"-100000"` for use in test fixtures.
///
/// # Panics
/// Panics if `Pnl::from_str` rejects `s`.
fn pnl(s: &str) -> Pnl {
    let parsed = Pnl::from_str(s);
    parsed.expect("pnl literal must be valid")
}
/// Builds a buy order for one unit of `AAPL/USD` with no limit price for `account_id`.
///
/// # Panics
/// Panics if the hard-coded asset or quantity literals fail to parse.
fn build_order(account_id: AccountId) -> OrderOperation {
    let instrument = Instrument::new(Asset::new("AAPL").expect("asset code must be valid"), usd());
    let quantity = Quantity::from_str("1").expect("quantity literal must be valid");
    OrderOperation {
        account_id,
        instrument,
        side: Side::Buy,
        trade_amount: TradeAmount::Quantity(quantity),
        price: None,
    }
}
/// Builds an `AAPL/USD` execution report for `account_id` carrying `pnl_val` and a zero fee.
///
/// # Panics
/// Panics if the hard-coded asset literal fails to parse.
fn build_report(account_id: AccountId, pnl_val: Pnl) -> TestReport {
    let aapl_usd = Instrument::new(Asset::new("AAPL").expect("asset code must be valid"), usd());
    TestReport {
        account_id,
        instrument: aapl_usd,
        fee: Fee::ZERO,
        pnl: pnl_val,
    }
}
/// Runs the policy's `check_pre_trade_start` hook for `order`.
///
/// The context passed in is built with `PreTradeContext::new(None)`. Returns `Ok(())`
/// when the policy accepts the order, or the policy's `Rejects` otherwise.
fn check_start(
    policy: &TestPolicy,
    order: &OrderOperation,
) -> Result<(), openpit::pretrade::Rejects> {
    let context = PreTradeContext::new(None);
    <TestPolicy as PreTradePolicy<OrderOperation, TestReport, (), FullSync>>::check_pre_trade_start(
        policy, &context, order,
    )
}
/// Feeds `report` to the policy's `apply_execution_report` hook.
///
/// Returns `true` only when the hook yields a present, non-empty result. An empty
/// result, or an absent one (`None`/`Err`, whichever the API returns), yields `false`.
/// The tests in this file interpret `true` as "this report triggered the kill switch".
fn apply_report(policy: &TestPolicy, report: &TestReport) -> bool {
    let context = PostTradeContext::<FullLocking>::new();
    let outcome =
        <TestPolicy as PreTradePolicy<OrderOperation, TestReport, (), FullSync>>::apply_execution_report(
            policy, &context, report,
        );
    outcome.map_or(false, |fired| !fired.is_empty())
}
/// Verifies that concurrent `apply_execution_report` calls on a full-sync policy do not
/// lose PnL updates.
///
/// The account-1 upper bound equals the exact PnL the workers apply in total
/// (`TOTAL_THREADS * PER_THREAD_REPORTS * PNL_PER_REPORT`). After every report, each
/// worker expects the order to still be accepted, because the running total should
/// never exceed that bound. Once all workers have joined, a +1 probe report must
/// breach the bound. If it does not, some updates were lost.
#[test]
fn pnl_full_sync_no_lost_updates_under_concurrent_apply() {
    let expected_total = TOTAL_THREADS * PER_THREAD_REPORTS * (PNL_PER_REPORT as usize);
    let account_upper_bound =
        Pnl::from_str(&expected_total.to_string()).expect("upper bound must be valid");

    let builder = Engine::builder::<OrderOperation, TestReport, ()>().full_sync();
    let broker_barrier = PnlBoundsBrokerBarrier {
        settlement_asset: usd(),
        lower_bound: Some(pnl("-100000")),
        upper_bound: Some(pnl("100000")),
    };
    let account_barrier = PnlBoundsAccountAssetBarrier {
        barrier: PnlBoundsBrokerBarrier {
            settlement_asset: usd(),
            lower_bound: Some(pnl("-100000")),
            upper_bound: Some(account_upper_bound),
        },
        account_id: account(1),
        initial_pnl: Pnl::ZERO,
    };
    let settings = PnlBoundsKillSwitchSettings::new([broker_barrier], [account_barrier])
        .expect("policy settings must be valid");
    let policy: Arc<TestPolicy> = Arc::new(PnlBoundsKillSwitchPolicy::new(
        settings,
        builder.storage_builder(),
    ));

    let pnl_delta = Pnl::from_str(&PNL_PER_REPORT.to_string()).expect("pnl delta must be valid");
    // Scoped threads are joined before `scope` returns, so they can borrow the
    // policy directly instead of each holding an `Arc` clone.
    let shared: &TestPolicy = &policy;
    thread::scope(|scope| {
        for _ in 0..TOTAL_THREADS {
            let order = build_order(account(1));
            scope.spawn(move || {
                for _ in 0..PER_THREAD_REPORTS {
                    apply_report(shared, &build_report(account(1), pnl_delta));
                    check_start(shared, &order)
                        .expect("check must pass: accumulated pnl within bounds during test");
                }
            });
        }
    });

    let breached = apply_report(&policy, &build_report(account(1), pnl("1")));
    assert!(
        breached,
        "probe report must breach upper bound {expected_total}: \
         this failure means updates were lost under concurrent load"
    );
}
/// Concurrently drives account 1 past its PnL upper bound and verifies that the kill
/// switch is monotonic and visible to later checks.
///
/// `TOTAL_THREADS` workers each apply `PER_THREAD_REPORTS` reports of `PNL_PER_REPORT`
/// against an account-1 barrier whose upper bound is 50, so the barrier is breached
/// part-way through the run.
///
/// Within each worker, once a pre-trade check has been rejected, every later check by
/// that worker must also be rejected with `PnlKillSwitchTriggered` (monotonicity).
/// After all workers join, a check on the main thread must see the trigger. Account 2,
/// which has no account-level barrier in this test, must still be accepted.
#[test]
fn pnl_full_sync_kill_switch_is_monotonic_and_visible_to_subsequent_checks() {
    let builder = Engine::builder::<OrderOperation, TestReport, ()>().full_sync();
    let policy: Arc<TestPolicy> = Arc::new(PnlBoundsKillSwitchPolicy::new(
        PnlBoundsKillSwitchSettings::new(
            [PnlBoundsBrokerBarrier {
                settlement_asset: usd(),
                lower_bound: Some(pnl("-100000")),
                upper_bound: Some(pnl("100000")),
            }],
            [PnlBoundsAccountAssetBarrier {
                barrier: PnlBoundsBrokerBarrier {
                    settlement_asset: usd(),
                    lower_bound: Some(pnl("-100000")),
                    upper_bound: Some(pnl("50")),
                },
                account_id: account(1),
                initial_pnl: Pnl::ZERO,
            }],
        )
        .expect("policy settings must be valid"),
        builder.storage_builder(),
    ));
    let pnl_delta = Pnl::from_str(&PNL_PER_REPORT.to_string()).expect("pnl delta must be valid");
    thread::scope(|s| {
        for _ in 0..TOTAL_THREADS {
            let policy = Arc::clone(&policy);
            s.spawn(move || {
                let order = build_order(account(1));
                // Set once this worker observes a rejection. From then on, the kill
                // switch must never report "accept" to this worker again.
                let mut observed_trigger = false;
                for _ in 0..PER_THREAD_REPORTS {
                    apply_report(&policy, &build_report(account(1), pnl_delta));
                    match check_start(&policy, &order) {
                        Ok(()) => assert!(
                            !observed_trigger,
                            "kill switch must stay triggered once a check has been rejected"
                        ),
                        Err(reject) => {
                            assert_eq!(
                                reject[0].code,
                                RejectCode::PnlKillSwitchTriggered,
                                "in-flight reject code must be PnlKillSwitchTriggered"
                            );
                            observed_trigger = true;
                        }
                    }
                }
            });
        }
    });
    let reject = check_start(&policy, &build_order(account(1)))
        .expect_err("kill switch must be triggered after concurrent breach");
    assert_eq!(
        reject[0].code,
        RejectCode::PnlKillSwitchTriggered,
        "reject code must be PnlKillSwitchTriggered"
    );
    check_start(&policy, &build_order(account(2)))
        .expect("account without triggered barrier must still accept");
}