use std::cell::OnceCell;
use std::fmt::{Display, Formatter};
use crate::param::{AccountGroupId, AccountId, DEFAULT_ACCOUNT_GROUP};
use crate::storage::{self, LockingPolicy, Storage, StorageBuilder};
/// Error returned when an account-group registration or unregistration request is rejected.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccountGroupError {
    /// The requested target was the reserved default account group.
    ReservedGroup,
    /// An account already belongs to a group, possibly the very group that was requested.
    AlreadyRegistered {
        /// The account that already has a group.
        account: AccountId,
        /// The group the account currently belongs to.
        current_group: AccountGroupId,
    },
    /// An account does not belong to the group it was expected to be in.
    NotInGroup {
        /// The account whose membership did not match.
        account: AccountId,
        /// The group the caller expected the account to be in.
        requested_group: AccountGroupId,
        /// The account's actual group, or `None` when it belongs to no group.
        current_group: Option<AccountGroupId>,
    },
}
impl Display for AccountGroupError {
    /// Writes a single-line, human-readable description of the error.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ReservedGroup => {
                formatter.write_str("the reserved default account group is not a valid target")
            }
            Self::AlreadyRegistered {
                account,
                current_group,
            } => write!(
                formatter,
                "account {account} is already registered in group {current_group}"
            ),
            Self::NotInGroup {
                account,
                requested_group,
                current_group,
            } => {
                // Both membership outcomes share this prefix; only the tail differs.
                write!(
                    formatter,
                    "account {account} is not in group {requested_group}; it belongs to "
                )?;
                match current_group {
                    Some(current) => write!(formatter, "group {current}"),
                    None => formatter.write_str("no group"),
                }
            }
        }
    }
}

impl std::error::Error for AccountGroupError {}
/// Registry mapping each account to the one group it has been registered into.
///
/// Mutating methods hold the write index of `guard` across their whole
/// validate-then-apply sequence. Lookups hold its read index.
pub(crate) struct AccountGroups<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Locking policy created from the storage builder; guards `memberships` as a whole.
    guard: StorageFactory::Policy,
    /// Account -> group table; an account with no entry belongs to no group.
    memberships: Storage<AccountId, AccountGroupId, StorageFactory::Policy>,
}
impl<StorageFactory> AccountGroups<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Creates an empty registry whose lock policy and membership table both come
    /// from `builder`.
    pub(crate) fn new(builder: &StorageBuilder<StorageFactory>) -> Self {
        let guard = builder.create_policy();
        let memberships = builder.create_for_bound_key();
        Self { guard, memberships }
    }

    /// Puts every account in `accounts` into `group`, all-or-nothing.
    ///
    /// The whole batch is validated before anything is written, so a rejected call
    /// leaves every membership unchanged. An empty slice succeeds without effect.
    ///
    /// # Errors
    ///
    /// - [`AccountGroupError::ReservedGroup`] if `group` is `DEFAULT_ACCOUNT_GROUP`.
    /// - [`AccountGroupError::AlreadyRegistered`] for the first account (in slice
    ///   order) that already has a group, including `group` itself.
    pub(crate) fn register_group(
        &self,
        accounts: &[AccountId],
        group: AccountGroupId,
    ) -> Result<(), AccountGroupError> {
        if group == DEFAULT_ACCOUNT_GROUP {
            return Err(AccountGroupError::ReservedGroup);
        }
        let _index_lock = self.guard.write_index();

        let conflict = accounts.iter().find_map(|&account| {
            self.memberships
                .with(&account, |existing| *existing)
                .map(|current_group| AccountGroupError::AlreadyRegistered {
                    account,
                    current_group,
                })
        });
        if let Some(error) = conflict {
            return Err(error);
        }

        for &account in accounts {
            self.memberships
                .with_mut(account, || group, |slot, _| *slot = group);
        }
        Ok(())
    }

    /// Removes every account in `accounts` from `group`, all-or-nothing.
    ///
    /// The whole batch is validated before anything is removed, so a rejected call
    /// leaves every membership unchanged.
    ///
    /// # Errors
    ///
    /// - [`AccountGroupError::ReservedGroup`] if `group` is `DEFAULT_ACCOUNT_GROUP`.
    /// - [`AccountGroupError::NotInGroup`] for the first account (in slice order)
    ///   whose current group is not `group`; it carries that account's actual group.
    pub(crate) fn unregister_group(
        &self,
        accounts: &[AccountId],
        group: AccountGroupId,
    ) -> Result<(), AccountGroupError> {
        if group == DEFAULT_ACCOUNT_GROUP {
            return Err(AccountGroupError::ReservedGroup);
        }
        let _index_lock = self.guard.write_index();

        let mismatch = accounts.iter().find_map(|&account| {
            let current_group = self.memberships.with(&account, |existing| *existing);
            if current_group == Some(group) {
                None
            } else {
                Some(AccountGroupError::NotInGroup {
                    account,
                    requested_group: group,
                    current_group,
                })
            }
        });
        if let Some(error) = mismatch {
            return Err(error);
        }

        for account in accounts {
            self.memberships.remove(account);
        }
        Ok(())
    }

    /// Returns the group `account` belongs to, or `None` if it has none.
    pub(crate) fn group_of(&self, account: AccountId) -> Option<AccountGroupId> {
        let _index_lock = self.guard.read_index();
        self.memberships.with(&account, |existing| *existing)
    }
}
/// Cloneable handle to an [`AccountGroups`] registry.
///
/// The registry is stored in the factory's `Shared` container. Whether clones observe
/// each other's updates depends on that container — presumably a shared pointer;
/// confirm against the `LockingPolicyFactory` implementation.
pub(crate) struct AccountGroupsHandle<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// The wrapped registry.
    inner: StorageFactory::Shared<AccountGroups<StorageFactory>>,
}
// Written by hand instead of derived so that `StorageFactory` itself need not be `Clone`;
// only the shared container is cloned.
impl<StorageFactory> Clone for AccountGroupsHandle<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    fn clone(&self) -> Self {
        let inner = Clone::clone(&self.inner);
        Self { inner }
    }
}
impl<StorageFactory> AccountGroupsHandle<StorageFactory>
where
StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
pub(crate) fn from_inner(inner: StorageFactory::Shared<AccountGroups<StorageFactory>>) -> Self {
Self { inner }
}
pub(crate) fn register_group(
&self,
accounts: &[AccountId],
group: AccountGroupId,
) -> Result<(), AccountGroupError> {
self.inner.register_group(accounts, group)
}
pub(crate) fn unregister_group(
&self,
accounts: &[AccountId],
group: AccountGroupId,
) -> Result<(), AccountGroupError> {
self.inner.unregister_group(accounts, group)
}
pub(crate) fn group_of(&self, account: AccountId) -> Option<AccountGroupId> {
self.inner.group_of(account)
}
}
/// Lazily resolves, and then memoizes, the group of one optional account.
pub(crate) struct GroupLookup<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Registry the group is read from.
    handle: AccountGroupsHandle<StorageFactory>,
    /// Account being looked up; `None` means there is no account to resolve.
    account: Option<AccountId>,
    /// Result of the first resolution; empty until the group is first requested.
    cached_group: OnceCell<Option<AccountGroupId>>,
}
impl<StorageFactory> GroupLookup<StorageFactory>
where
    StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
    /// Creates a lookup for `account`; the registry is not queried here.
    pub(crate) fn new(
        handle: AccountGroupsHandle<StorageFactory>,
        account: Option<AccountId>,
    ) -> Self {
        let cached_group = OnceCell::new();
        Self {
            handle,
            account,
            cached_group,
        }
    }

    /// Returns the bound account's group, querying the registry at most once.
    ///
    /// The first answer is memoized, so registry changes made after the first call
    /// are not reflected. Returns `None` when no account is bound or the account
    /// belongs to no group.
    pub(crate) fn group(&self) -> Option<AccountGroupId> {
        *self.cached_group.get_or_init(|| {
            let account = self.account?;
            self.handle.group_of(account)
        })
    }
}
impl<StorageFactory> crate::marketdata::AccountInfo for GroupLookup<StorageFactory>
where
StorageFactory: storage::LockingPolicyFactory + storage::CreateStorageFor<AccountId> + 'static,
{
fn group(&self) -> Option<AccountGroupId> {
self.group()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::param::{AccountGroupId, AccountId};
    use crate::storage::{LockingPolicyFactory, NoLocking, StorageBuilder};

    /// Creates an empty registry backed by the `NoLocking` policy.
    fn new_registry() -> AccountGroups<NoLocking> {
        AccountGroups::new(&StorageBuilder::new(NoLocking))
    }

    /// Builds an account id from a raw integer.
    fn account(id: u64) -> AccountId {
        AccountId::from_u64(id)
    }

    /// Builds a group id from a raw integer; panics if `from_u32` rejects it.
    fn group(id: u32) -> AccountGroupId {
        AccountGroupId::from_u32(id).expect("account group id must be valid")
    }

    #[test]
    fn register_group_happy_path() {
        let registry = new_registry();
        registry
            .register_group(&[account(1), account(2)], group(7))
            .expect("registration must succeed");
        assert_eq!(registry.group_of(account(1)), Some(group(7)));
        assert_eq!(registry.group_of(account(2)), Some(group(7)));
    }

    #[test]
    fn register_group_rejects_account_in_other_group_and_changes_nothing() {
        let registry = new_registry();
        registry
            .register_group(&[account(1)], group(1))
            .expect("first registration must succeed");
        let error = registry
            .register_group(&[account(2), account(1)], group(2))
            .expect_err("registration must fail on conflict");
        assert_eq!(
            error,
            AccountGroupError::AlreadyRegistered {
                account: account(1),
                current_group: group(1),
            }
        );
        assert_eq!(registry.group_of(account(2)), None);
        assert_eq!(registry.group_of(account(1)), Some(group(1)));
    }

    #[test]
    fn register_group_rejects_account_already_in_same_group_and_changes_nothing() {
        let registry = new_registry();
        registry
            .register_group(&[account(1)], group(5))
            .expect("first registration must succeed");
        let error = registry
            .register_group(&[account(2), account(1)], group(5))
            .expect_err("re-registering into the same group must fail");
        assert_eq!(
            error,
            AccountGroupError::AlreadyRegistered {
                account: account(1),
                current_group: group(5),
            }
        );
        assert_eq!(registry.group_of(account(2)), None);
        assert_eq!(registry.group_of(account(1)), Some(group(5)));
    }

    #[test]
    fn unregister_group_happy_path() {
        let registry = new_registry();
        registry
            .register_group(&[account(1), account(2)], group(3))
            .expect("registration must succeed");
        registry
            .unregister_group(&[account(1), account(2)], group(3))
            .expect("unregistration must succeed");
        assert_eq!(registry.group_of(account(1)), None);
        assert_eq!(registry.group_of(account(2)), None);
    }

    #[test]
    fn unregister_group_rejects_ungrouped_account_and_removes_nothing() {
        let registry = new_registry();
        registry
            .register_group(&[account(1)], group(3))
            .expect("registration must succeed");
        let error = registry
            .unregister_group(&[account(1), account(2)], group(3))
            .expect_err("unregistration must fail when an account is ungrouped");
        assert_eq!(
            error,
            AccountGroupError::NotInGroup {
                account: account(2),
                requested_group: group(3),
                current_group: None,
            }
        );
        assert_eq!(registry.group_of(account(1)), Some(group(3)));
    }

    #[test]
    fn unregister_group_rejects_account_in_other_group_and_removes_nothing() {
        let registry = new_registry();
        registry
            .register_group(&[account(1)], group(3))
            .expect("registration must succeed");
        registry
            .register_group(&[account(2)], group(4))
            .expect("registration must succeed");
        let error = registry
            .unregister_group(&[account(1), account(2)], group(3))
            .expect_err("unregistration must fail on group mismatch");
        assert_eq!(
            error,
            AccountGroupError::NotInGroup {
                account: account(2),
                requested_group: group(3),
                current_group: Some(group(4)),
            }
        );
        assert_eq!(registry.group_of(account(1)), Some(group(3)));
        assert_eq!(registry.group_of(account(2)), Some(group(4)));
    }

    #[test]
    fn group_of_present_and_absent() {
        let registry = new_registry();
        registry
            .register_group(&[account(1)], group(9))
            .expect("registration must succeed");
        assert_eq!(registry.group_of(account(1)), Some(group(9)));
        assert_eq!(registry.group_of(account(2)), None);
    }

    #[test]
    fn register_group_empty_slice_is_noop() {
        let registry = new_registry();
        registry
            .register_group(&[], group(1))
            .expect("empty registration must succeed");
        assert_eq!(registry.group_of(account(1)), None);
    }

    #[test]
    fn register_group_rejects_reserved_default_group() {
        let registry = new_registry();
        let error = registry
            .register_group(&[account(1)], DEFAULT_ACCOUNT_GROUP)
            .expect_err("registering into the default group must fail");
        assert_eq!(error, AccountGroupError::ReservedGroup);
        assert_eq!(registry.group_of(account(1)), None);
    }

    #[test]
    fn unregister_group_rejects_reserved_default_group() {
        let registry = new_registry();
        let error = registry
            .unregister_group(&[account(1)], DEFAULT_ACCOUNT_GROUP)
            .expect_err("unregistering from the default group must fail");
        assert_eq!(error, AccountGroupError::ReservedGroup);
    }

    #[test]
    fn account_group_error_display_is_stable() {
        assert_eq!(
            AccountGroupError::ReservedGroup.to_string(),
            "the reserved default account group is not a valid target"
        );
        let already = AccountGroupError::AlreadyRegistered {
            account: account(1),
            current_group: group(2),
        };
        assert_eq!(
            already.to_string(),
            "account 1 is already registered in group 2"
        );
        let mismatch = AccountGroupError::NotInGroup {
            account: account(1),
            requested_group: group(2),
            current_group: Some(group(3)),
        };
        assert_eq!(
            mismatch.to_string(),
            "account 1 is not in group 2; it belongs to group 3"
        );
        let ungrouped = AccountGroupError::NotInGroup {
            account: account(1),
            requested_group: group(2),
            current_group: None,
        };
        assert_eq!(
            ungrouped.to_string(),
            "account 1 is not in group 2; it belongs to no group"
        );
    }

    /// Wraps a fresh registry in a handle via `NoLocking`'s shared container.
    fn new_handle() -> AccountGroupsHandle<NoLocking> {
        AccountGroupsHandle::from_inner(NoLocking::new_shared(new_registry()))
    }

    #[test]
    fn pre_trade_context_group_returns_bound_account_group() {
        use crate::pretrade::PreTradeContext;
        let handle = new_handle();
        handle
            .inner
            .register_group(&[account(1)], group(7))
            .expect("registration must succeed");
        let ctx = PreTradeContext::with_groups(None, handle, Some(account(1)));
        assert_eq!(ctx.account_group(), Some(group(7)));
    }

    #[test]
    fn pre_trade_context_group_is_none_when_account_absent() {
        use crate::pretrade::PreTradeContext;
        let handle = new_handle();
        handle
            .inner
            .register_group(&[account(1)], group(7))
            .expect("registration must succeed");
        let ctx = PreTradeContext::with_groups(None, handle, None);
        assert_eq!(ctx.account_group(), None);
    }

    #[test]
    fn pre_trade_context_group_is_cached_after_first_call() {
        use crate::pretrade::PreTradeContext;
        let handle = new_handle();
        handle
            .inner
            .register_group(&[account(1)], group(7))
            .expect("registration must succeed");
        let ctx = PreTradeContext::with_groups(None, handle.clone(), Some(account(1)));
        assert_eq!(ctx.account_group(), Some(group(7)));
        handle
            .inner
            .unregister_group(&[account(1)], group(7))
            .expect("unregistration must succeed");
        handle
            .inner
            .register_group(&[account(1)], group(9))
            .expect("re-registration must succeed");
        // The first answer is memoized, so the later move to group 9 is not observed.
        assert_eq!(ctx.account_group(), Some(group(7)));
    }

    #[test]
    fn post_trade_context_group_returns_report_account_group() {
        use crate::pretrade::PostTradeContext;
        let handle = new_handle();
        handle
            .inner
            .register_group(&[account(5)], group(3))
            .expect("registration must succeed");
        let ctx = PostTradeContext::with_groups(handle, Some(account(5)));
        assert_eq!(ctx.account_group(), Some(group(3)));
    }

    #[test]
    fn account_adjustment_context_group_returns_adjusted_account_group() {
        use crate::core::account_control::BlockedAccounts;
        use crate::core::{AccountBlockHandle, AccountControl};
        use crate::AccountAdjustmentContext;
        let handle = new_handle();
        handle
            .inner
            .register_group(&[account(8)], group(4))
            .expect("registration must succeed");
        let block_handle = AccountBlockHandle::from_inner(NoLocking::new_shared(
            BlockedAccounts::new(&StorageBuilder::new(NoLocking)),
        ));
        let control = AccountControl::new(block_handle, account(8));
        let ctx = AccountAdjustmentContext::with_groups(control, handle, account(8));
        assert_eq!(ctx.account_group(), Some(group(4)));
    }

    #[test]
    fn full_locking_failed_register_leaves_membership_unchanged_under_contention() {
        use crate::storage::FullLocking;
        use std::thread;

        let registry: AccountGroups<FullLocking> =
            AccountGroups::new(&StorageBuilder::new(FullLocking));
        registry
            .register_group(&[account(0)], group(100))
            .expect("seed registration must succeed");

        // `thread::scope` joins every worker before returning, so the workers can
        // borrow the registry directly; no `Arc` is needed.
        let registry = &registry;
        thread::scope(|scope| {
            for tid in 1..=8u64 {
                scope.spawn(move || {
                    let target = group(u32::try_from(tid).expect("thread id must fit in u32"));
                    for _ in 0..200 {
                        let result =
                            registry.register_group(&[account(tid), account(0)], target);
                        assert!(result.is_err(), "batch with conflict must fail");
                    }
                });
            }
        });

        for tid in 1..=8u64 {
            assert_eq!(registry.group_of(account(tid)), None);
        }
        assert_eq!(registry.group_of(account(0)), Some(group(100)));
    }
}