use crate::core::HasAccountId;
use crate::param::AccountId;
use crate::pretrade::{AccountBlock, Reject, RejectCode, RejectScope, Rejects};
use crate::storage::{self, IndexFlag, Storage, StorageBuilder};
/// Control object bound to one account, allowing the holder to block that account.
///
/// Every block issued through it is forwarded, tagged with the bound [`AccountId`],
/// to the shared [`AccountBlockHandle`].
pub struct AccountControl<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Handle the blocks are recorded through.
    handle: AccountBlockHandle<StorageFactory>,
    /// Account that every block issued by this control applies to.
    account_id: AccountId,
}

// Implemented by hand: `#[derive(Clone)]` would add a `StorageFactory: Clone` bound that the
// factory type is not required to satisfy.
impl<StorageFactory> Clone for AccountControl<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    fn clone(&self) -> Self {
        let Self { handle, account_id } = self;
        Self::new(handle.clone(), *account_id)
    }
}

impl<StorageFactory> AccountControl<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Binds `handle` to `account_id`.
    pub(crate) fn new(handle: AccountBlockHandle<StorageFactory>, account_id: AccountId) -> Self {
        Self { account_id, handle }
    }

    /// Blocks the bound account, recording `block` as the cause.
    pub fn block(&self, block: AccountBlock) {
        let Self { handle, account_id } = self;
        handle.record(*account_id, block);
    }
}
/// Cloneable handle through which account blocks are written into the shared
/// blocked-accounts registry.
pub struct AccountBlockHandle<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Registry held in the locking policy's `Shared` wrapper; cloning the handle clones
    /// this wrapper.
    inner: StorageFactory::Shared<BlockedAccounts<StorageFactory>>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `StorageFactory: Clone` bound that the
// factory type is not required to satisfy.
impl<StorageFactory> Clone for AccountBlockHandle<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    fn clone(&self) -> Self {
        Self::from_inner(self.inner.clone())
    }
}

impl<StorageFactory> AccountBlockHandle<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Wraps an already-shared registry.
    pub(crate) fn from_inner(
        inner: StorageFactory::Shared<BlockedAccounts<StorageFactory>>,
    ) -> Self {
        Self { inner }
    }

    /// Blocks `account_id`, recording `block` as the cause.
    pub(crate) fn record(&self, account_id: AccountId, block: AccountBlock) {
        self.inner.block_account(account_id, block)
    }
}
/// Builds the single rejection returned for an identified account when every account is
/// blocked and no account-specific cause has been recorded for it.
fn new_account_blocked_rejects() -> Rejects {
    let reject = Reject::new(
        "Engine",
        RejectScope::Account,
        RejectCode::AccountBlocked,
        "account is blocked due to kill-switch",
        String::from("kill-switch was previously triggered for this account"),
    );
    Rejects::new(vec![reject])
}
/// Builds the single rejection returned when blocks are active but the order's account ID
/// cannot be read, so the order cannot be proven unaffected.
///
/// `scope` is the caller's operation scope and is copied into the reject unchanged.
fn new_unverifiable_blocked_rejects(scope: RejectScope) -> Rejects {
    let reject = Reject::new(
        "Engine",
        scope,
        RejectCode::MissingRequiredField,
        "account could not be verified as account ID is missing",
        String::from("unable to check account for blocking"),
    );
    Rejects::new(vec![reject])
}
/// Registry of blocked accounts, plus a switch that blocks every account at once.
///
/// Two flags sit in front of the account map:
/// - `any_flag` is raised by every kind of block, so `check` can return early while
///   nothing at all is blocked;
/// - `all_flag` is raised when a block is recorded for a report without an account ID,
///   which blocks every account.
pub(crate) struct BlockedAccounts<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Raised once anything has been blocked; nothing in this type lowers it again.
    any_flag: <StorageFactory as storage::LockingPolicyFactory>::IndexFlag,
    /// Raised once every account has been blocked; nothing in this type lowers it again.
    all_flag: <StorageFactory as storage::LockingPolicyFactory>::IndexFlag,
    /// Per-account block causes, keyed by account ID.
    accounts: Storage<AccountId, AccountBlock, StorageFactory::Policy>,
}

impl<StorageFactory> BlockedAccounts<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Creates an empty registry (nothing blocked) whose account map is built by `builder`.
    pub(crate) fn new(builder: &StorageBuilder<StorageFactory>) -> Self {
        Self {
            any_flag: <StorageFactory as storage::LockingPolicyFactory>::IndexFlag::new(false),
            all_flag: <StorageFactory as storage::LockingPolicyFactory>::IndexFlag::new(false),
            accounts: builder.create(),
        }
    }

    /// Checks whether `order` may proceed.
    ///
    /// Returns:
    /// - `None` if nothing is blocked, or if the order's account is neither blocked
    ///   individually nor covered by a block-all;
    /// - the recorded cause if the order's account was blocked individually;
    /// - an `AccountBlocked` reject if every account is blocked;
    /// - a `MissingRequiredField` reject in `operation_scope` if anything is blocked but
    ///   the order's account ID cannot be read.
    pub(crate) fn check<Order: HasAccountId>(
        &self,
        order: &Order,
        operation_scope: RejectScope,
    ) -> Option<Rejects> {
        // Fast path: nothing has been blocked yet.
        //
        // Deliberately no assertion on `all_flag` here: `block_all` raises the two flags with
        // separate stores and this method reads them with separate loads. If the registry
        // is shared between threads, a reader can legitimately observe `any_flag == false`
        // followed by `all_flag == true`, so such an assertion could fire spuriously.
        if !self.any_flag.load() {
            return None;
        }
        match order.account_id() {
            // Something is blocked, and without an account ID this order cannot be proven
            // unaffected.
            Err(_) => Some(new_unverifiable_blocked_rejects(operation_scope)),
            Ok(id) => {
                // An account-specific block (with its recorded cause) takes precedence
                // over the global block-all.
                if let Some(rejects) = self
                    .accounts
                    .with(&id, |b| Rejects::new(vec![Reject::from(b.clone())]))
                {
                    return Some(rejects);
                }
                if self.all_flag.load() {
                    return Some(new_account_blocked_rejects());
                }
                None
            }
        }
    }

    /// Records a block triggered by `report`.
    ///
    /// If the report's account ID is readable, only that account is blocked. Otherwise
    /// every account is blocked, because the affected account cannot be identified.
    pub(crate) fn record<Report: HasAccountId>(&self, report: &Report, cause: AccountBlock) {
        match report.account_id() {
            Ok(id) => self.block_account(id, cause),
            Err(_) => self.block_all(),
        }
    }

    /// Blocks account `id` with `cause`.
    ///
    /// If the account is already blocked, its existing cause is kept (the update closure
    /// is a no-op), so the first recorded cause wins.
    pub(crate) fn block_account(&self, id: AccountId, cause: AccountBlock) {
        self.accounts.with_mut(id, || cause, |_, _| ());
        // The entry is written before `any_flag` is raised so that a reader passing the
        // fast path can find it (presumably relying on the policy's flag ordering).
        self.any_flag.store(true);
    }

    /// Blocks every account.
    fn block_all(&self) {
        self.all_flag.store(true);
        self.any_flag.store(true);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::HasAccountId;
    use crate::param::AccountId;
    use crate::pretrade::RejectCode;
    use crate::storage::{NoLocking, StorageBuilder};
    use crate::RequestFieldAccessError;

    /// Fresh, single-threaded registry with nothing blocked.
    fn new_set() -> BlockedAccounts<NoLocking> {
        BlockedAccounts::new(&StorageBuilder::new(NoLocking))
    }

    /// Builds a block cause attributed to `policy` with reject code `code`.
    fn cause(policy: &str, code: RejectCode) -> AccountBlock {
        AccountBlock::new(policy, code, "test block", "details")
    }

    /// Order/report whose account ID is readable.
    struct AccountOrder(AccountId);

    impl HasAccountId for AccountOrder {
        fn account_id(&self) -> Result<AccountId, RequestFieldAccessError> {
            Ok(self.0)
        }
    }

    /// Order/report whose account ID cannot be read.
    struct NoAccountOrder;

    impl HasAccountId for NoAccountOrder {
        fn account_id(&self) -> Result<AccountId, RequestFieldAccessError> {
            Err(RequestFieldAccessError::new("account_id"))
        }
    }

    fn account(id: u64) -> AccountId {
        AccountId::from_u64(id)
    }

    #[test]
    fn initially_nothing_blocked() {
        let set = new_set();
        assert!(set
            .check(&AccountOrder(account(1)), RejectScope::Order)
            .is_none());
        assert!(set.check(&NoAccountOrder, RejectScope::Order).is_none());
    }

    #[test]
    fn record_account_blocks_that_account() {
        let set = new_set();
        set.record(
            &AccountOrder(account(1)),
            cause("Policy", RejectCode::PnlKillSwitchTriggered),
        );
        assert!(set
            .check(&AccountOrder(account(1)), RejectScope::Order)
            .is_some());
    }

    #[test]
    fn record_account_does_not_block_other_accounts() {
        let set = new_set();
        set.record(
            &AccountOrder(account(1)),
            cause("Policy", RejectCode::PnlKillSwitchTriggered),
        );
        assert!(set
            .check(&AccountOrder(account(2)), RejectScope::Order)
            .is_none());
    }

    #[test]
    fn record_no_account_blocks_every_account() {
        let set = new_set();
        set.record(
            &NoAccountOrder,
            cause("Policy", RejectCode::PnlKillSwitchTriggered),
        );
        assert!(set
            .check(&AccountOrder(account(1)), RejectScope::Order)
            .is_some());
        assert!(set
            .check(&AccountOrder(account(99)), RejectScope::Order)
            .is_some());
    }

    #[test]
    fn record_no_account_blocks_unidentifiable_orders() {
        let set = new_set();
        set.record(
            &NoAccountOrder,
            cause("Policy", RejectCode::PnlKillSwitchTriggered),
        );
        assert!(set.check(&NoAccountOrder, RejectScope::Order).is_some());
    }

    #[test]
    fn record_account_blocks_unidentifiable_orders() {
        let set = new_set();
        set.record(
            &AccountOrder(account(1)),
            cause("Policy", RejectCode::PnlKillSwitchTriggered),
        );
        assert!(set.check(&NoAccountOrder, RejectScope::Order).is_some());
    }

    #[test]
    fn initially_unidentifiable_order_is_allowed() {
        let set = new_set();
        assert!(set.check(&NoAccountOrder, RejectScope::Order).is_none());
    }

    #[test]
    fn check_returns_cause_for_blocked_account() {
        let set = new_set();
        set.record(
            &AccountOrder(account(1)),
            cause("KillSwitch", RejectCode::PnlKillSwitchTriggered),
        );
        let rejects = set
            .check(&AccountOrder(account(1)), RejectScope::Order)
            .expect("blocked account must return rejects");
        assert_eq!(rejects.len(), 1);
        assert_eq!(rejects[0].policy, "KillSwitch");
        assert_eq!(rejects[0].code, RejectCode::PnlKillSwitchTriggered);
        assert_eq!(rejects[0].scope, RejectScope::Account);
    }

    #[test]
    fn first_cause_wins_on_repeated_block() {
        let set = new_set();
        set.record(
            &AccountOrder(account(1)),
            cause("First", RejectCode::PnlKillSwitchTriggered),
        );
        set.record(
            &AccountOrder(account(1)),
            cause("Second", RejectCode::Other),
        );
        let rejects = set
            .check(&AccountOrder(account(1)), RejectScope::Order)
            .expect("blocked account must return rejects");
        assert_eq!(rejects[0].policy, "First");
    }

    #[test]
    fn unidentifiable_order_reports_missing_field_in_operation_scope() {
        let set = new_set();
        set.record(
            &AccountOrder(account(1)),
            cause("Policy", RejectCode::PnlKillSwitchTriggered),
        );
        let rejects = set
            .check(&NoAccountOrder, RejectScope::Order)
            .expect("unidentifiable order must be rejected while blocks are active");
        assert_eq!(rejects.len(), 1);
        assert_eq!(rejects[0].code, RejectCode::MissingRequiredField);
        assert_eq!(rejects[0].scope, RejectScope::Order);
    }

    #[test]
    fn block_all_reports_account_blocked_for_unlisted_account() {
        let set = new_set();
        set.record(
            &NoAccountOrder,
            cause("Policy", RejectCode::PnlKillSwitchTriggered),
        );
        let rejects = set
            .check(&AccountOrder(account(7)), RejectScope::Order)
            .expect("every account must be rejected after a block-all");
        assert_eq!(rejects.len(), 1);
        assert_eq!(rejects[0].code, RejectCode::AccountBlocked);
        assert_eq!(rejects[0].scope, RejectScope::Account);
    }

    #[test]
    fn account_cause_takes_precedence_over_block_all() {
        let set = new_set();
        set.record(
            &AccountOrder(account(1)),
            cause("Specific", RejectCode::PnlKillSwitchTriggered),
        );
        set.record(&NoAccountOrder, cause("Global", RejectCode::Other));
        let rejects = set
            .check(&AccountOrder(account(1)), RejectScope::Order)
            .expect("blocked account must return rejects");
        assert_eq!(rejects.len(), 1);
        assert_eq!(rejects[0].policy, "Specific");
        assert_eq!(rejects[0].code, RejectCode::PnlKillSwitchTriggered);
    }
}