use std::{collections::HashMap, fmt::Display};
use nautilus_core::{UUID4, UnixNanos};
use serde::{Deserialize, Serialize};
use crate::{
enums::AccountType,
identifiers::{AccountId, InstrumentId},
types::{AccountBalance, Currency, MarginBalance},
};
/// An event describing the state of an account: its balances and margins at a point in time.
///
/// Equality (see the `PartialEq` impl) considers only `account_id`, `account_type`
/// and `event_id`; use [`AccountState::has_same_balances_and_margins`] to compare the
/// balances and margins themselves.
// C-compatible layout: the field declaration order below is part of the memory layout.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
feature = "python",
pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model", from_py_object)
)]
#[cfg_attr(
feature = "python",
pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.model")
)]
pub struct AccountState {
/// The account ID associated with the event (part of the equality comparison).
pub account_id: AccountId,
/// The account type (rendered e.g. as `CASH` or `MARGIN` by `Display`); part of equality.
pub account_type: AccountType,
/// The account's base currency, if any; `Display` renders `None` when absent.
pub base_currency: Option<Currency>,
/// Account balances; `has_same_balances_and_margins` matches these by currency.
pub balances: Vec<AccountBalance>,
/// Margin balances; `has_same_balances_and_margins` matches these by instrument ID.
pub margins: Vec<MarginBalance>,
/// Whether this state was reported.
/// NOTE(review): presumably reported by the venue as opposed to calculated locally — confirm.
pub is_reported: bool,
/// Identifier of this event (part of the equality comparison).
pub event_id: UUID4,
/// Event timestamp (UNIX nanoseconds).
pub ts_event: UnixNanos,
/// Initialization timestamp (UNIX nanoseconds).
pub ts_init: UnixNanos,
}
impl AccountState {
    /// Creates a new [`AccountState`] instance.
    ///
    /// All values are stored as given. Note that `base_currency` is the last argument even
    /// though it is declared as the third field of the struct.
    #[allow(clippy::too_many_arguments)] // One argument per field of the event.
    pub fn new(
        account_id: AccountId,
        account_type: AccountType,
        balances: Vec<AccountBalance>,
        margins: Vec<MarginBalance>,
        is_reported: bool,
        event_id: UUID4,
        ts_event: UnixNanos,
        ts_init: UnixNanos,
        base_currency: Option<Currency>,
    ) -> Self {
        Self {
            account_id,
            account_type,
            base_currency,
            balances,
            margins,
            is_reported,
            event_id,
            ts_event,
            ts_init,
        }
    }

    /// Returns `true` if `self` and `other` hold the same balances and margins.
    ///
    /// Both states must hold the same number of balances and the same number of margins.
    /// Balances are then matched by currency and margins by instrument ID, so the order of
    /// entries does not matter. Identity fields, `base_currency`, `is_reported` and the
    /// timestamps are not compared.
    ///
    /// If a collection holds several entries for the same key, only the last one is used
    /// (the rule `HashMap` collection applies). The comparison is symmetric:
    /// `a.has_same_balances_and_margins(&b) == b.has_same_balances_and_margins(&a)`.
    pub fn has_same_balances_and_margins(&self, other: &Self) -> bool {
        // Cheap early exit; matching entry counts are also part of the contract.
        if self.balances.len() != other.balances.len() || self.margins.len() != other.margins.len()
        {
            return false;
        }

        let self_balances: HashMap<Currency, &AccountBalance> = self
            .balances
            .iter()
            .map(|balance| (balance.currency, balance))
            .collect();
        let other_balances: HashMap<Currency, &AccountBalance> = other
            .balances
            .iter()
            .map(|balance| (balance.currency, balance))
            .collect();

        // Full map equality (same key set, equal values) instead of a one-way lookup.
        // With duplicate keys the maps can differ in size even though the `Vec`s do not,
        // and a one-way check would then make the result depend on argument order.
        if self_balances != other_balances {
            return false;
        }

        let self_margins: HashMap<InstrumentId, &MarginBalance> = self
            .margins
            .iter()
            .map(|margin| (margin.instrument_id, margin))
            .collect();
        let other_margins: HashMap<InstrumentId, &MarginBalance> = other
            .margins
            .iter()
            .map(|margin| (margin.instrument_id, margin))
            .collect();

        self_margins == other_margins
    }
}
impl Display for AccountState {
    /// Formats the event as
    /// `AccountState(account_id=.., account_type=.., base_currency=.., is_reported=..,
    /// balances=[..], margins=[..], event_id=..)`.
    ///
    /// Balances and margins use their own `Display` output, separated by `", "`. A missing
    /// base currency is rendered as `None`. Everything is written straight into the
    /// formatter, so no intermediate strings are allocated.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Writes each item's `Display` output separated by `", "`.
        fn write_joined<T: Display>(
            f: &mut std::fmt::Formatter<'_>,
            items: &[T],
        ) -> std::fmt::Result {
            for (index, item) in items.iter().enumerate() {
                if index > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{item}")?;
            }
            Ok(())
        }

        write!(
            f,
            "{}(account_id={}, account_type={}, base_currency=",
            stringify!(AccountState),
            self.account_id,
            self.account_type,
        )?;
        match &self.base_currency {
            Some(base_currency) => write!(f, "{}", base_currency.code)?,
            None => f.write_str("None")?,
        }
        write!(f, ", is_reported={}, balances=[", self.is_reported)?;
        write_joined(f, &self.balances)?;
        f.write_str("], margins=[")?;
        write_joined(f, &self.margins)?;
        write!(f, "], event_id={})", self.event_id)
    }
}
impl PartialEq for AccountState {
    /// Two account states are equal when they describe the same event: the same
    /// `account_id`, `account_type` and `event_id`. Balances, margins, base currency,
    /// `is_reported` and timestamps are not part of this comparison.
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares element by element, left to right, stopping at the first mismatch.
        (&self.account_id, &self.account_type, &self.event_id)
            == (&other.account_id, &other.account_type, &other.event_id)
    }
}
#[cfg(test)]
mod tests {
    use nautilus_core::{UUID4, UnixNanos};
    use rstest::rstest;

    use crate::{
        enums::AccountType,
        events::{
            AccountState,
            account::stubs::{cash_account_state, margin_account_state},
        },
        identifiers::{AccountId, InstrumentId},
        types::{AccountBalance, Currency, MarginBalance, Money},
    };

    #[rstest]
    fn test_equality() {
        let cash_account_state_1 = cash_account_state();
        let cash_account_state_2 = cash_account_state();
        assert_eq!(cash_account_state_1, cash_account_state_2);
    }

    // Equality is identity-based: a different event ID alone makes two states unequal.
    #[rstest]
    fn test_inequality_when_event_id_differs() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        state2.event_id = UUID4::new();
        assert_ne!(state1, state2);
    }

    #[rstest]
    fn test_display_cash_account_state(cash_account_state: AccountState) {
        let display = format!("{cash_account_state}");
        assert_eq!(
            display,
            "AccountState(account_id=SIM-001, account_type=CASH, base_currency=USD, is_reported=true, \
             balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
             margins=[], event_id=16578139-a945-4b65-b46c-bc131a15d8e7)"
        );
    }

    #[rstest]
    fn test_display_margin_account_state(margin_account_state: AccountState) {
        let display = format!("{margin_account_state}");
        assert_eq!(
            display,
            "AccountState(account_id=SIM-001, account_type=MARGIN, base_currency=USD, is_reported=true, \
             balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
             margins=[MarginBalance(initial=5000.00 USD, maintenance=20000.00 USD, instrument_id=BTCUSDT.COINBASE)], \
             event_id=16578139-a945-4b65-b46c-bc131a15d8e7)"
        );
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_identical() {
        let state1 = cash_account_state();
        let state2 = cash_account_state();
        assert!(state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_amounts() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        let usd = Currency::USD();
        let different_balance = AccountBalance::new(
            Money::new(2000000.0, usd),
            Money::new(50000.0, usd),
            Money::new(1950000.0, usd),
        );
        state2.balances = vec![different_balance];
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_currencies() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        let eur = Currency::EUR();
        let different_balance = AccountBalance::new(
            Money::new(1525000.0, eur),
            Money::new(25000.0, eur),
            Money::new(1500000.0, eur),
        );
        state2.balances = vec![different_balance];
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_balance() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        let eur = Currency::EUR();
        let additional_balance = AccountBalance::new(
            Money::new(1000000.0, eur),
            Money::new(0.0, eur),
            Money::new(1000000.0, eur),
        );
        state2.balances.push(additional_balance);
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_amounts() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        let usd = Currency::USD();
        let instrument_id = InstrumentId::from("BTCUSDT.COINBASE");
        let different_margin = MarginBalance::new(
            Money::new(10000.0, usd),
            Money::new(40000.0, usd),
            instrument_id,
        );
        state2.margins = vec![different_margin];
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_instruments() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        let usd = Currency::USD();
        let different_instrument_id = InstrumentId::from("ETHUSDT.BINANCE");
        let different_margin = MarginBalance::new(
            Money::new(5000.0, usd),
            Money::new(20000.0, usd),
            different_instrument_id,
        );
        state2.margins = vec![different_margin];
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_margin() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        let usd = Currency::USD();
        let additional_instrument_id = InstrumentId::from("ETHUSDT.BINANCE");
        let additional_margin = MarginBalance::new(
            Money::new(3000.0, usd),
            Money::new(15000.0, usd),
            additional_instrument_id,
        );
        state2.margins.push(additional_margin);
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_empty_collections() {
        let account_id = AccountId::new("TEST-001");
        let event_id = UUID4::new();
        let ts_event = UnixNanos::from(1);
        let ts_init = UnixNanos::from(2);

        let state1 = AccountState::new(
            account_id,
            AccountType::Cash,
            vec![],
            vec![],
            true,
            event_id,
            ts_event,
            ts_init,
            Some(Currency::USD()),
        );
        let state2 = AccountState::new(
            account_id,
            AccountType::Cash,
            vec![],
            vec![],
            true,
            UUID4::new(),
            UnixNanos::from(3),
            UnixNanos::from(4),
            Some(Currency::USD()),
        );

        assert!(state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_multiple_balances_and_margins() {
        let account_id = AccountId::new("TEST-001");
        let event_id = UUID4::new();
        let ts_event = UnixNanos::from(1);
        let ts_init = UnixNanos::from(2);
        let usd = Currency::USD();
        let eur = Currency::EUR();
        let btc_instrument = InstrumentId::from("BTCUSDT.COINBASE");
        let eth_instrument = InstrumentId::from("ETHUSDT.BINANCE");

        let balances = vec![
            AccountBalance::new(
                Money::new(1000000.0, usd),
                Money::new(0.0, usd),
                Money::new(1000000.0, usd),
            ),
            AccountBalance::new(
                Money::new(500000.0, eur),
                Money::new(10000.0, eur),
                Money::new(490000.0, eur),
            ),
        ];
        let margins = vec![
            MarginBalance::new(
                Money::new(5000.0, usd),
                Money::new(20000.0, usd),
                btc_instrument,
            ),
            MarginBalance::new(
                Money::new(3000.0, usd),
                Money::new(15000.0, usd),
                eth_instrument,
            ),
        ];

        let state1 = AccountState::new(
            account_id,
            AccountType::Margin,
            balances.clone(),
            margins.clone(),
            true,
            event_id,
            ts_event,
            ts_init,
            Some(usd),
        );
        let state2 = AccountState::new(
            account_id,
            AccountType::Margin,
            balances,
            margins,
            true,
            UUID4::new(),
            UnixNanos::from(3),
            UnixNanos::from(4),
            Some(usd),
        );

        assert!(state1.has_same_balances_and_margins(&state2));
    }

    // Entries are matched by key (currency / instrument ID), not by position, and the
    // comparison gives the same answer in both directions.
    #[rstest]
    fn test_has_same_balances_and_margins_ignores_ordering() {
        let account_id = AccountId::new("TEST-001");
        let usd = Currency::USD();
        let eur = Currency::EUR();

        let balances = vec![
            AccountBalance::new(
                Money::new(1000000.0, usd),
                Money::new(0.0, usd),
                Money::new(1000000.0, usd),
            ),
            AccountBalance::new(
                Money::new(500000.0, eur),
                Money::new(10000.0, eur),
                Money::new(490000.0, eur),
            ),
        ];
        let margins = vec![
            MarginBalance::new(
                Money::new(5000.0, usd),
                Money::new(20000.0, usd),
                InstrumentId::from("BTCUSDT.COINBASE"),
            ),
            MarginBalance::new(
                Money::new(3000.0, usd),
                Money::new(15000.0, usd),
                InstrumentId::from("ETHUSDT.BINANCE"),
            ),
        ];

        let mut reversed_balances = balances.clone();
        reversed_balances.reverse();
        let mut reversed_margins = margins.clone();
        reversed_margins.reverse();

        let state1 = AccountState::new(
            account_id,
            AccountType::Margin,
            balances,
            margins,
            true,
            UUID4::new(),
            UnixNanos::from(1),
            UnixNanos::from(2),
            Some(usd),
        );
        let state2 = AccountState::new(
            account_id,
            AccountType::Margin,
            reversed_balances,
            reversed_margins,
            true,
            UUID4::new(),
            UnixNanos::from(3),
            UnixNanos::from(4),
            Some(usd),
        );

        assert!(state1.has_same_balances_and_margins(&state2));
        assert!(state2.has_same_balances_and_margins(&state1));
    }
}