use std::sync::Arc;
use std::thread;
use std::time::Duration;
use openpit::param::{AccountId, Asset, Quantity, Side, TradeAmount};
use openpit::pretrade::policies::{
RateLimit, RateLimitBrokerBarrier, RateLimitPolicy, RateLimitSettings,
};
use openpit::pretrade::{PreTradeContext, PreTradePolicy};
use openpit::storage::FullLocking;
use openpit::{Engine, FullSync, Instrument, OrderOperation};
/// Policy under test: `RateLimitPolicy` parameterised with `FullLocking` storage.
type TestPolicy = RateLimitPolicy<FullLocking>;

/// Number of worker threads issuing pre-trade checks concurrently.
const TOTAL_THREADS: usize = 8;
/// Number of pre-trade checks issued by each worker thread.
const PER_THREAD: usize = 1000;
/// Builds a buy order for a quantity of `1` on the `AAPL`/`USD` instrument,
/// attributed to `account_id`, with no price set.
///
/// # Panics
///
/// Panics if either hard-coded asset code or the quantity literal is rejected.
fn build_order(account_id: AccountId) -> OrderOperation {
    let first_asset = Asset::new("AAPL").expect("asset code must be valid");
    let second_asset = Asset::new("USD").expect("asset code must be valid");
    let one_unit = Quantity::from_str("1").expect("quantity literal must be valid");

    OrderOperation {
        instrument: Instrument::new(first_asset, second_asset),
        account_id,
        side: Side::Buy,
        trade_amount: TradeAmount::Quantity(one_unit),
        price: None,
    }
}
/// Concurrent pre-trade checks must not lose broker-level counter updates.
///
/// `TOTAL_THREADS` workers each issue `PER_THREAD` checks against a broker barrier
/// whose `max_orders` equals the total number of calls, so every one of those calls
/// is expected to pass. One more check is then issued from the main thread and must
/// be rejected. If any increment had been lost under contention, the counter would
/// sit below the limit and that extra call would wrongly pass.
#[test]
fn rate_limit_full_sync_broker_counter_not_lost_under_concurrent_load() {
    let total_calls = TOTAL_THREADS * PER_THREAD;
    let builder = Engine::builder::<OrderOperation, (), ()>().full_sync();
    let policy: Arc<TestPolicy> = Arc::new(RateLimitPolicy::<FullLocking>::new(
        RateLimitSettings::new(
            Some(RateLimitBrokerBarrier {
                limit: RateLimit {
                    max_orders: total_calls,
                    // NOTE(review): chosen to be far longer than the test run, assuming
                    // counts only expire once `window` has elapsed — confirm.
                    window: Duration::from_secs(60),
                },
            }),
            [],
            [],
            [],
        )
        .expect("rate limit settings must be valid"),
        builder.storage_builder(),
    ));

    // Release every worker at the same moment so their checks genuinely overlap.
    // Without this, threads spawned early can finish their short loops before later
    // ones start, leaving little contention on the shared counter.
    let start_line = std::sync::Barrier::new(TOTAL_THREADS);

    thread::scope(|s| {
        for tid in 0..TOTAL_THREADS {
            // Scoped threads may borrow from this frame, so no `Arc::clone` is needed.
            let policy = &policy;
            let start_line = &start_line;
            s.spawn(move || {
                // Wait first: nothing before this point can panic, so no worker can
                // leave the others blocked on the barrier forever.
                start_line.wait();
                let account = u64::try_from(tid).expect("thread index must fit in u64");
                let order = build_order(AccountId::from_u64(account));
                for _ in 0..PER_THREAD {
                    <TestPolicy as PreTradePolicy<OrderOperation, (), (), FullSync>>::check_pre_trade_start(
                        policy,
                        &PreTradeContext::new(None),
                        &order,
                    )
                    .expect("all calls within limit must pass");
                }
            });
        }
    });

    // Every slot in the broker budget is now used; exactly one more call must fail.
    let overflow_order = build_order(AccountId::from_u64(99));
    <TestPolicy as PreTradePolicy<OrderOperation, (), (), FullSync>>::check_pre_trade_start(
        &policy,
        &PreTradeContext::new(None),
        &overflow_order,
    )
    .expect_err("call after exhausting limit must be rejected");
}