use chrono::{DateTime, Utc};
use derive_more::Constructor;
use indexmap::IndexMap;
use rust_decimal::Decimal;
pub use rustrade_execution::order::id::PositionId;
use rustrade_execution::trade::{AssetFees, Trade, TradeId};
use rustrade_instrument::{Side, asset::AssetIndex, instrument::InstrumentIndex};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use tracing::error;
/// How a [`PositionManager`] assigns incoming trades to positions in
/// `PositionManager::update_from_trade`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
pub enum OmsMode {
    /// Default mode: every trade updates the single position stored under
    /// `PositionId::NETTING`.
    #[default]
    Netting,
    /// Every trade updates a position keyed by the trade's order id
    /// (`PositionId::new(trade.order_id.0.clone())`), so each order gets its own position.
    Hedging,
}
/// Tracks the open positions produced by a stream of trades, keyed by [`PositionId`].
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct PositionManager<AssetKey = AssetIndex, InstrumentKey = InstrumentIndex> {
    /// Decides which `PositionId` `update_from_trade` routes a trade to.
    pub mode: OmsMode,
    /// Currently open positions. A fully closed position is removed. An updated position
    /// is removed and re-inserted, so it moves to the end of the iteration order.
    pub positions: IndexMap<PositionId, Position<AssetKey, InstrumentKey>>,
}
/// An empty manager in the default [`OmsMode`] (`Netting`).
impl<AssetKey, InstrumentKey> Default for PositionManager<AssetKey, InstrumentKey> {
    fn default() -> Self {
        Self::new(OmsMode::default())
    }
}
impl<AssetKey, InstrumentKey> PositionManager<AssetKey, InstrumentKey> {
    /// Creates a manager with no open positions that operates in `mode`.
    ///
    /// The position map is pre-sized for a single entry.
    pub fn new(mode: OmsMode) -> Self {
        let positions = IndexMap::with_capacity(1);
        Self { mode, positions }
    }
}
impl<AssetKey: Debug + Clone, InstrumentKey> PositionManager<AssetKey, InstrumentKey> {
pub fn update_from_trade(
&mut self,
trade: &Trade<AssetKey, InstrumentKey>,
contract_size: Decimal,
) -> Option<PositionExited<AssetKey, InstrumentKey>>
where
InstrumentKey: Debug + Clone + PartialEq,
{
let position_id = match self.mode {
OmsMode::Netting => PositionId::NETTING,
OmsMode::Hedging => PositionId::new(trade.order_id.0.clone()),
};
self.update_from_trade_with_id(trade, &position_id, contract_size)
}
pub fn update_from_trade_with_id(
&mut self,
trade: &Trade<AssetKey, InstrumentKey>,
position_id: &PositionId,
contract_size: Decimal,
) -> Option<PositionExited<AssetKey, InstrumentKey>>
where
InstrumentKey: Debug + Clone + PartialEq,
{
let (updated, closed) = match self.positions.shift_remove(position_id) {
Some(mut position) => {
position.contract_size = contract_size;
position.update_from_trade(trade)
}
None => {
let mut position = Position::from(trade);
position.contract_size = contract_size;
(Some(position), None)
}
};
if self.mode == OmsMode::Hedging
&& let (Some(new_pos), Some(exited)) = (&updated, &closed)
{
error!(
%position_id,
closed_side = ?exited.side,
new_side = ?new_pos.side,
"Hedging mode: fill crossed zero — position flipped under same PositionId. \
Subsequent fills routed to this ID will update the opposite-direction \
position silently, corrupting PnL. Strategies MUST close positions \
explicitly; never rely on flip semantics in Hedging mode.",
);
}
let closed = closed.map(|mut e| {
e.position_id = position_id.clone();
e
});
if let Some(pos) = updated {
self.positions.insert(position_id.clone(), pos);
}
closed
}
}
/// An open position in a single instrument, built and updated from trades.
///
/// Field order is significant:
/// - it defines the derived `PartialOrd` comparison;
/// - it fixes the argument order of the `Constructor`-generated `new`;
/// - it sets the serialization order.
#[derive(Debug, Clone, PartialEq, PartialOrd, Deserialize, Serialize, Constructor)]
pub struct Position<AssetKey = AssetIndex, InstrumentKey = InstrumentIndex> {
    /// Instrument this position is held in. `update_from_trade` ignores (and logs)
    /// trades for any other instrument.
    pub instrument: InstrumentKey,
    /// Direction of the position. PnL is computed as long for `Buy` and short for `Sell`.
    pub side: Side,
    /// Quantity-weighted average price of the trades that opened or increased the position.
    /// It is not changed by reducing trades.
    pub price_entry_average: Decimal,
    /// Current absolute open quantity.
    pub quantity_abs: Decimal,
    /// Largest `quantity_abs` reached while the position was open.
    pub quantity_abs_max: Decimal,
    /// Most recently computed mark-to-market PnL, net of an estimate of the remaining exit
    /// fees. It is refreshed whenever a trade is applied or `update_pnl_unrealised` is called.
    pub pnl_unrealised: Decimal,
    /// Cumulative realised PnL, with every entry and exit fee already deducted.
    /// Fees are charged at `fees_quote` when present, otherwise at `fees`.
    pub pnl_realised: Decimal,
    /// Accumulated fees of trades that opened or increased the position.
    pub fees_enter: AssetFees<AssetKey>,
    /// Accumulated fees of trades that reduced or closed the position.
    pub fees_exit: AssetFees<AssetKey>,
    /// Exchange time of the trade that opened the position.
    pub time_enter: DateTime<Utc>,
    /// Exchange time of the most recently applied trade.
    pub time_exchange_update: DateTime<Utc>,
    /// Ids of every trade applied to this position, in order.
    pub trades: Vec<TradeId>,
    /// Multiplier applied to `quantity * price` in PnL calculations.
    /// Deserializes to 1 when the field is absent.
    #[serde(default = "default_contract_size")]
    pub contract_size: Decimal,
}
/// Serde fallback for [`Position::contract_size`]: a multiplier of exactly 1 (scale 0).
fn default_contract_size() -> Decimal {
    Decimal::new(1, 0)
}
impl<AssetKey, InstrumentKey> Position<AssetKey, InstrumentKey> {
pub fn update_from_trade(
mut self,
trade: &Trade<AssetKey, InstrumentKey>,
) -> (
Option<Self>,
Option<PositionExited<AssetKey, InstrumentKey>>,
)
where
AssetKey: Debug + Clone,
InstrumentKey: Debug + Clone + PartialEq,
{
if self.instrument != trade.instrument {
error!(
position = ?self,
trade = ?trade,
"Position tried to be updated from a Trade for a different Instrument - ignoring"
);
return (Some(self), None);
}
self.trades.push(trade.id.clone());
use Side::*;
match (self.side, trade.side) {
(Buy, Buy) | (Sell, Sell) => {
self.update_price_entry_average(trade);
self.quantity_abs += trade.quantity.abs();
if self.quantity_abs > self.quantity_abs_max {
self.quantity_abs_max = self.quantity_abs;
}
self.pnl_realised -= trade.fees.fees_quote.unwrap_or(trade.fees.fees);
self.fees_enter.fees += trade.fees.fees;
self.fees_enter.fees_quote =
match (self.fees_enter.fees_quote, trade.fees.fees_quote) {
(Some(a), Some(b)) => Some(a + b),
_ => None,
};
self.time_exchange_update = trade.time_exchange;
self.update_pnl_unrealised(trade.price);
(Some(self), None)
}
(Buy, Sell) | (Sell, Buy) if self.quantity_abs > trade.quantity.abs() => {
let closed_fee_quote = trade.fees.fees_quote.unwrap_or(trade.fees.fees);
self.update_pnl_realised(trade.quantity, trade.price, closed_fee_quote);
self.quantity_abs -= trade.quantity.abs();
self.fees_exit.fees += trade.fees.fees;
self.fees_exit.fees_quote = match (self.fees_exit.fees_quote, trade.fees.fees_quote)
{
(Some(a), Some(b)) => Some(a + b),
_ => None,
};
self.time_exchange_update = trade.time_exchange;
self.update_pnl_unrealised(trade.price);
(Some(self), None)
}
(Buy, Sell) | (Sell, Buy) if self.quantity_abs == trade.quantity.abs() => {
self.quantity_abs -= trade.quantity.abs();
self.fees_exit.fees += trade.fees.fees;
self.fees_exit.fees_quote = match (self.fees_exit.fees_quote, trade.fees.fees_quote)
{
(Some(a), Some(b)) => Some(a + b),
_ => None,
};
self.time_exchange_update = trade.time_exchange;
let closed_fee_quote = trade.fees.fees_quote.unwrap_or(trade.fees.fees);
self.update_pnl_realised(trade.quantity, trade.price, closed_fee_quote);
self.update_pnl_unrealised(trade.price);
(None, Some(PositionExited::from(self)))
}
(Buy, Sell) | (Sell, Buy) if self.quantity_abs < trade.quantity.abs() => {
let next_position_quantity = trade.quantity.abs() - self.quantity_abs;
let next_position_fee_enter =
trade.fees.fees * (next_position_quantity / trade.quantity.abs());
let next_position_trade = Trade {
id: trade.id.clone(),
order_id: trade.order_id.clone(),
instrument: trade.instrument.clone(),
strategy: trade.strategy.clone(),
time_exchange: trade.time_exchange,
side: trade.side,
price: trade.price,
quantity: next_position_quantity,
fees: AssetFees {
asset: trade.fees.asset.clone(),
fees: next_position_fee_enter,
fees_quote: trade
.fees
.fees_quote
.map(|fq| fq * (next_position_quantity / trade.quantity.abs())),
},
};
let fee_exit = trade.fees.fees * (self.quantity_abs / trade.quantity.abs());
let fee_exit_quote = trade
.fees
.fees_quote
.map(|fq| fq * (self.quantity_abs / trade.quantity.abs()));
self.fees_exit.fees += fee_exit;
self.fees_exit.fees_quote = match (self.fees_exit.fees_quote, fee_exit_quote) {
(Some(a), Some(b)) => Some(a + b),
_ => None,
};
self.time_exchange_update = trade.time_exchange;
self.update_pnl_realised(
self.quantity_abs,
trade.price,
fee_exit_quote.unwrap_or(fee_exit),
);
self.quantity_abs = Decimal::ZERO;
self.update_pnl_unrealised(trade.price);
let mut next_position = Self::from(&next_position_trade);
next_position.contract_size = self.contract_size;
(Some(next_position), Some(PositionExited::from(self)))
}
_ => unreachable!("match expression guard statements cover all cases"),
}
}
fn update_price_entry_average(&mut self, trade: &Trade<AssetKey, InstrumentKey>) {
self.price_entry_average = calculate_price_entry_average(
self.price_entry_average,
self.quantity_abs,
trade.price,
trade.quantity.abs(),
);
}
pub fn update_pnl_unrealised(&mut self, price: Decimal) {
self.pnl_unrealised = calculate_pnl_unrealised(
self.side,
self.price_entry_average,
self.quantity_abs,
self.quantity_abs_max,
self.fees_enter.fees,
price,
self.contract_size,
);
}
pub fn update_pnl_realised(
&mut self,
closed_quantity: Decimal,
closed_price: Decimal,
closed_fee: Decimal,
) {
self.pnl_realised += calculate_pnl_realised(
self.side,
self.price_entry_average,
closed_quantity,
closed_price,
closed_fee,
self.contract_size,
);
}
}
impl<AssetKey, InstrumentKey> From<&Trade<AssetKey, InstrumentKey>>
    for Position<AssetKey, InstrumentKey>
where
    AssetKey: Clone,
    InstrumentKey: Clone,
{
    /// Opens a new position from a single trade.
    ///
    /// - The trade's side and price become the position's side and entry price.
    /// - Its absolute quantity becomes both the open and the maximum quantity.
    /// - Its fee is recorded as the entry fee and charged to `pnl_realised` (at
    ///   `fees_quote` when present, otherwise at `fees`).
    /// - Exit fees start at zero.
    /// - `contract_size` starts at 1. The callers in this module overwrite it with the
    ///   real value.
    fn from(trade: &Trade<AssetKey, InstrumentKey>) -> Self {
        let quantity_abs = trade.quantity.abs();
        let entry_fee = trade.fees.fees_quote.unwrap_or(trade.fees.fees);

        // Presumably sized for the opening trade plus one follow-up trade.
        let mut trades = Vec::with_capacity(2);
        trades.push(trade.id.clone());

        Self {
            instrument: trade.instrument.clone(),
            side: trade.side,
            price_entry_average: trade.price,
            quantity_abs,
            quantity_abs_max: quantity_abs,
            pnl_unrealised: Decimal::ZERO,
            pnl_realised: -entry_fee,
            fees_enter: trade.fees.clone(),
            fees_exit: AssetFees::new(
                trade.fees.asset.clone(),
                Decimal::ZERO,
                Some(Decimal::ZERO),
            ),
            time_enter: trade.time_exchange,
            time_exchange_update: trade.time_exchange,
            trades,
            contract_size: Decimal::ONE,
        }
    }
}
/// Final record of a position that has been fully closed (including the closed half of a
/// flip).
///
/// Built from a [`Position`] via `From`. `PositionManager` then sets the `position_id` the
/// position was stored under.
///
/// Field order is significant:
/// - it defines the derived `Ord`/`PartialOrd` comparison and the derived hash;
/// - it fixes the argument order of the `Constructor`-generated `new`;
/// - it sets the serialization order.
#[derive(
    Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Deserialize, Serialize, Constructor,
)]
pub struct PositionExited<AssetKey, InstrumentKey = InstrumentIndex> {
    /// Id the position was stored under. `From<Position>` fills in `PositionId::default()`,
    /// which is also used when the field is absent during deserialization.
    #[serde(default)]
    pub position_id: PositionId,
    /// Instrument the position was held in.
    pub instrument: InstrumentKey,
    /// Side of the closed position (not of the trade that closed it).
    pub side: Side,
    /// Quantity-weighted average entry price of the position.
    pub price_entry_average: Decimal,
    /// Largest absolute quantity held while the position was open.
    pub quantity_abs_max: Decimal,
    /// Total realised PnL, with all entry and exit fees deducted.
    pub pnl_realised: Decimal,
    /// Accumulated fees of the trades that opened or increased the position.
    pub fees_enter: AssetFees<AssetKey>,
    /// Accumulated fees of the trades that reduced or closed the position.
    pub fees_exit: AssetFees<AssetKey>,
    /// Exchange time of the opening trade.
    pub time_enter: DateTime<Utc>,
    /// Exchange time of the last trade applied, i.e. the closing trade.
    pub time_exit: DateTime<Utc>,
    /// Ids of every trade applied to the position, in order.
    pub trades: Vec<TradeId>,
}
impl<AssetKey, InstrumentKey> From<Position<AssetKey, InstrumentKey>>
for PositionExited<AssetKey, InstrumentKey>
{
fn from(value: Position<AssetKey, InstrumentKey>) -> Self {
Self {
position_id: PositionId::default(),
instrument: value.instrument,
side: value.side,
price_entry_average: value.price_entry_average,
quantity_abs_max: value.quantity_abs_max,
pnl_realised: value.pnl_realised,
fees_enter: value.fees_enter,
fees_exit: value.fees_exit,
time_enter: value.time_enter,
time_exit: value.time_exchange_update,
trades: value.trades,
}
}
}
/// Quantity-weighted average of the current entry price and a new trade's price.
///
/// Returns zero when both quantities are zero, which avoids a division by zero.
fn calculate_price_entry_average(
    current_price_entry_average: Decimal,
    current_quantity_abs: Decimal,
    trade_price: Decimal,
    trade_quantity_abs: Decimal,
) -> Decimal {
    let nothing_held_or_traded = current_quantity_abs.is_zero() && trade_quantity_abs.is_zero();
    if nothing_held_or_traded {
        Decimal::ZERO
    } else {
        let combined_notional = current_price_entry_average * current_quantity_abs
            + trade_price * trade_quantity_abs;
        combined_notional / (current_quantity_abs + trade_quantity_abs)
    }
}
/// Mark-to-market PnL of `quantity_abs` units at `price` against `price_entry_average`.
///
/// Notionals are scaled by `contract_size`. The result is net of an exit-fee estimate
/// proportional to the entry fees (see `approximate_remaining_exit_fees`). A `Buy` position
/// gains when `price` rises; a `Sell` position gains when it falls.
pub fn calculate_pnl_unrealised(
    position_side: Side,
    price_entry_average: Decimal,
    quantity_abs: Decimal,
    quantity_abs_max: Decimal,
    fees_enter: Decimal,
    price: Decimal,
    contract_size: Decimal,
) -> Decimal {
    let exit_fee_estimate =
        approximate_remaining_exit_fees(quantity_abs, quantity_abs_max, fees_enter);
    let notional_now = quantity_abs * price * contract_size;
    let notional_entry = quantity_abs * price_entry_average * contract_size;

    let gross = match position_side {
        Side::Buy => notional_now - notional_entry,
        Side::Sell => notional_entry - notional_now,
    };
    gross - exit_fee_estimate
}
/// Estimates the fees still to be paid to exit the open quantity.
///
/// Assumes exit fees mirror entry fees, scaled by the fraction
/// `quantity_abs / quantity_abs_max`. Returns zero when `quantity_abs_max` is zero.
fn approximate_remaining_exit_fees(
    quantity_abs: Decimal,
    quantity_abs_max: Decimal,
    fees_enter: Decimal,
) -> Decimal {
    if quantity_abs_max.is_zero() {
        Decimal::ZERO
    } else {
        let open_fraction = quantity_abs / quantity_abs_max;
        open_fraction * fees_enter
    }
}
/// PnL realised by closing `closed_quantity` (sign ignored) at `closed_price`.
///
/// Measured against `price_entry_average`, with notionals scaled by `contract_size`, and
/// net of `closed_fee` (a negative fee, e.g. a rebate, increases PnL).
pub fn calculate_pnl_realised(
    position_side: Side,
    price_entry_average: Decimal,
    closed_quantity: Decimal,
    closed_price: Decimal,
    closed_fee: Decimal,
    contract_size: Decimal,
) -> Decimal {
    let quantity = closed_quantity.abs();
    let exit_notional = quantity * closed_price * contract_size;
    let entry_notional = quantity * price_entry_average * contract_size;

    let gross = match position_side {
        Side::Buy => exit_notional - entry_notional,
        Side::Sell => entry_notional - exit_notional,
    };
    gross - closed_fee
}
/// Realised return: `pnl_realised / (price_entry_average * quantity_abs_max)`.
///
/// Returns zero when either factor of the denominator is zero.
///
/// NOTE(review): `calculate_pnl_realised` scales PnL by `contract_size`, but this
/// denominator does not. Confirm this is intended for instruments with a contract size
/// other than 1.
pub fn calculate_pnl_return(
    pnl_realised: Decimal,
    price_entry_average: Decimal,
    quantity_abs_max: Decimal,
) -> Decimal {
    let has_cost_basis = !price_entry_average.is_zero() && !quantity_abs_max.is_zero();
    if has_cost_basis {
        pnl_realised / (price_entry_average * quantity_abs_max)
    } else {
        Decimal::ZERO
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{time_plus_days, trade};
    use rust_decimal_macros::dec;
    use rustrade_instrument::{asset::QuoteAsset, instrument::name::InstrumentNameInternal};

    /// Instrument every expected position / exit in these tests refers to.
    fn instrument() -> InstrumentNameInternal {
        InstrumentNameInternal::new("instrument")
    }

    /// Quote-asset fees whose quote-denominated value equals the raw amount.
    fn quote_fees(amount: Decimal) -> AssetFees<QuoteAsset> {
        AssetFees {
            asset: QuoteAsset,
            fees: amount,
            fees_quote: Some(amount),
        }
    }

    /// `count` copies of the trade id asserted on by these tests.
    fn trade_ids(count: usize) -> Vec<TradeId> {
        vec![TradeId::new("trade_id"); count]
    }

    #[test]
    fn test_position_update_from_trade() {
        struct TestCase {
            initial_trade: Trade<QuoteAsset, InstrumentNameInternal>,
            update_trade: Trade<QuoteAsset, InstrumentNameInternal>,
            expected_position: Option<Position<QuoteAsset, InstrumentNameInternal>>,
            expected_position_exited: Option<PositionExited<QuoteAsset, InstrumentNameInternal>>,
        }

        let base_time = DateTime::<Utc>::MIN_UTC;
        let next_day = time_plus_days(base_time, 1);

        let cases = [
            // TC0: Buy 1 @ 100 then Buy 1 @ 120 -> larger long position.
            TestCase {
                initial_trade: trade(base_time, Side::Buy, 100.0, 1.0, 10.0),
                update_trade: trade(next_day, Side::Buy, 120.0, 1.0, 10.0),
                expected_position: Some(Position {
                    instrument: instrument(),
                    side: Side::Buy,
                    // (100 * 1 + 120 * 1) / 2
                    price_entry_average: dec!(110.0),
                    quantity_abs: dec!(2.0),
                    quantity_abs_max: dec!(2.0),
                    // (2 * 120 - 2 * 110) - 20 estimated exit fees
                    pnl_unrealised: dec!(0.0),
                    // two entry fees of 10
                    pnl_realised: dec!(-20.0),
                    fees_enter: quote_fees(dec!(20.0)),
                    fees_exit: quote_fees(dec!(0.0)),
                    time_enter: base_time,
                    time_exchange_update: next_day,
                    trades: trade_ids(2),
                    contract_size: Decimal::ONE,
                }),
                expected_position_exited: None,
            },
            // TC1: Buy 2 @ 100 then Sell 0.5 @ 150 -> partially reduced long.
            TestCase {
                initial_trade: trade(base_time, Side::Buy, 100.0, 2.0, 10.0),
                update_trade: trade(next_day, Side::Sell, 150.0, 0.5, 5.0),
                expected_position: Some(Position {
                    instrument: instrument(),
                    side: Side::Buy,
                    price_entry_average: dec!(100.0),
                    quantity_abs: dec!(1.5),
                    quantity_abs_max: dec!(2.0),
                    // (1.5 * 150 - 1.5 * 100) - (1.5 / 2) * 10
                    pnl_unrealised: dec!(67.5),
                    // -10 entry fee + (0.5 * 150 - 0.5 * 100 - 5)
                    pnl_realised: dec!(10.0),
                    fees_enter: quote_fees(dec!(10.0)),
                    fees_exit: quote_fees(dec!(5.0)),
                    time_enter: base_time,
                    time_exchange_update: next_day,
                    trades: trade_ids(2),
                    contract_size: Decimal::ONE,
                }),
                expected_position_exited: None,
            },
            // TC2: Buy 1 @ 100 then Sell 1 @ 150 -> long fully closed.
            TestCase {
                initial_trade: trade(base_time, Side::Buy, 100.0, 1.0, 10.0),
                update_trade: trade(next_day, Side::Sell, 150.0, 1.0, 10.0),
                expected_position: None,
                expected_position_exited: Some(PositionExited {
                    position_id: PositionId::NETTING,
                    instrument: instrument(),
                    side: Side::Buy,
                    price_entry_average: dec!(100.0),
                    quantity_abs_max: dec!(1.0),
                    // -10 entry fee + (150 - 100 - 10)
                    pnl_realised: dec!(30.0),
                    fees_enter: quote_fees(dec!(10.0)),
                    fees_exit: quote_fees(dec!(10.0)),
                    time_enter: base_time,
                    time_exit: next_day,
                    trades: trade_ids(2),
                }),
            },
            // TC3: Buy 1 @ 100 then Sell 2 @ 150 -> long closed, new short of 1 opened.
            TestCase {
                initial_trade: trade(base_time, Side::Buy, 100.0, 1.0, 10.0),
                update_trade: trade(next_day, Side::Sell, 150.0, 2.0, 20.0),
                expected_position: Some(Position {
                    instrument: instrument(),
                    side: Side::Sell,
                    price_entry_average: dec!(150.0),
                    quantity_abs: dec!(1.0),
                    quantity_abs_max: dec!(1.0),
                    pnl_unrealised: dec!(0.0),
                    // half of the 20 fee is the new position's entry fee
                    pnl_realised: dec!(-10.0),
                    fees_enter: quote_fees(dec!(10.0)),
                    fees_exit: quote_fees(dec!(0.0)),
                    time_enter: next_day,
                    time_exchange_update: next_day,
                    trades: trade_ids(1),
                    contract_size: Decimal::ONE,
                }),
                expected_position_exited: Some(PositionExited {
                    position_id: PositionId::NETTING,
                    instrument: instrument(),
                    side: Side::Buy,
                    price_entry_average: dec!(100.0),
                    quantity_abs_max: dec!(1.0),
                    // -10 entry fee + (150 - 100 - 10 exit share of the fee)
                    pnl_realised: dec!(30.0),
                    fees_enter: quote_fees(dec!(10.0)),
                    fees_exit: quote_fees(dec!(10.0)),
                    time_enter: base_time,
                    time_exit: next_day,
                    trades: trade_ids(2),
                }),
            },
            // TC4: Sell 1 @ 100 then Sell 1 @ 80 -> larger short position.
            TestCase {
                initial_trade: trade(base_time, Side::Sell, 100.0, 1.0, 10.0),
                update_trade: trade(base_time, Side::Sell, 80.0, 1.0, 10.0),
                expected_position: Some(Position {
                    instrument: instrument(),
                    side: Side::Sell,
                    // (100 * 1 + 80 * 1) / 2
                    price_entry_average: dec!(90.0),
                    quantity_abs: dec!(2.0),
                    quantity_abs_max: dec!(2.0),
                    // (2 * 90 - 2 * 80) - 20 estimated exit fees
                    pnl_unrealised: dec!(0.0),
                    // two entry fees of 10
                    pnl_realised: dec!(-20.0),
                    fees_enter: quote_fees(dec!(20.0)),
                    fees_exit: quote_fees(dec!(0.0)),
                    time_enter: base_time,
                    time_exchange_update: base_time,
                    trades: trade_ids(2),
                    contract_size: Decimal::ONE,
                }),
                expected_position_exited: None,
            },
            // TC5: Sell 2 @ 100 then Buy 0.5 @ 80 -> partially reduced short.
            TestCase {
                initial_trade: trade(base_time, Side::Sell, 100.0, 2.0, 10.0),
                update_trade: trade(base_time, Side::Buy, 80.0, 0.5, 5.0),
                expected_position: Some(Position {
                    instrument: instrument(),
                    side: Side::Sell,
                    price_entry_average: dec!(100.0),
                    quantity_abs: dec!(1.5),
                    quantity_abs_max: dec!(2.0),
                    // (1.5 * 100 - 1.5 * 80) - (1.5 / 2) * 10
                    pnl_unrealised: dec!(22.5),
                    // -10 entry fee + (0.5 * 100 - 0.5 * 80 - 5)
                    pnl_realised: dec!(-5.0),
                    fees_enter: quote_fees(dec!(10.0)),
                    fees_exit: quote_fees(dec!(5.0)),
                    time_enter: base_time,
                    time_exchange_update: base_time,
                    trades: trade_ids(2),
                    contract_size: Decimal::ONE,
                }),
                expected_position_exited: None,
            },
            // TC6: Sell 1 @ 100 then Buy 1 @ 80 -> short fully closed.
            TestCase {
                initial_trade: trade(base_time, Side::Sell, 100.0, 1.0, 10.0),
                update_trade: trade(base_time, Side::Buy, 80.0, 1.0, 10.0),
                expected_position: None,
                expected_position_exited: Some(PositionExited {
                    position_id: PositionId::NETTING,
                    instrument: instrument(),
                    side: Side::Sell,
                    price_entry_average: dec!(100.0),
                    quantity_abs_max: dec!(1.0),
                    // -10 entry fee + (100 - 80 - 10)
                    pnl_realised: dec!(0.0),
                    fees_enter: quote_fees(dec!(10.0)),
                    fees_exit: quote_fees(dec!(10.0)),
                    time_enter: base_time,
                    time_exit: base_time,
                    trades: trade_ids(2),
                }),
            },
            // TC7: Sell 1 @ 100 then Buy 2 @ 80 -> short closed, new long of 1 opened.
            TestCase {
                initial_trade: trade(base_time, Side::Sell, 100.0, 1.0, 10.0),
                update_trade: trade(base_time, Side::Buy, 80.0, 2.0, 20.0),
                expected_position: Some(Position {
                    instrument: instrument(),
                    side: Side::Buy,
                    price_entry_average: dec!(80.0),
                    quantity_abs: dec!(1.0),
                    quantity_abs_max: dec!(1.0),
                    pnl_unrealised: dec!(0.0),
                    // half of the 20 fee is the new position's entry fee
                    pnl_realised: dec!(-10.0),
                    fees_enter: quote_fees(dec!(10.0)),
                    fees_exit: quote_fees(dec!(0.0)),
                    time_enter: base_time,
                    time_exchange_update: base_time,
                    trades: trade_ids(1),
                    contract_size: Decimal::ONE,
                }),
                expected_position_exited: Some(PositionExited {
                    position_id: PositionId::NETTING,
                    instrument: instrument(),
                    side: Side::Sell,
                    price_entry_average: dec!(100.0),
                    quantity_abs_max: dec!(1.0),
                    // -10 entry fee + (100 - 80 - 10 exit share of the fee)
                    pnl_realised: dec!(0.0),
                    fees_enter: quote_fees(dec!(10.0)),
                    fees_exit: quote_fees(dec!(10.0)),
                    time_enter: base_time,
                    time_exit: base_time,
                    trades: trade_ids(2),
                }),
            },
        ];

        for (index, case) in cases.into_iter().enumerate() {
            let position = Position::from(&case.initial_trade);
            let (updated_position, exited_position) =
                position.update_from_trade(&case.update_trade);
            assert_eq!(updated_position, case.expected_position, "TC{index} failed");
            assert_eq!(
                exited_position, case.expected_position_exited,
                "TC{index} failed"
            );
        }
    }

    #[test]
    fn test_calculate_price_entry_average() {
        struct TestCase {
            current_price_entry_average: Decimal,
            current_quantity_abs: Decimal,
            trade_price: Decimal,
            trade_quantity_abs: Decimal,
            expected: Decimal,
        }

        let cases = [
            // TC0: equal weights
            TestCase {
                current_price_entry_average: dec!(100.0),
                current_quantity_abs: dec!(2.0),
                trade_price: dec!(200.0),
                trade_quantity_abs: dec!(2.0),
                expected: dec!(150.0),
            },
            // TC1: trade weighted 2:1
            TestCase {
                current_price_entry_average: dec!(100.0),
                current_quantity_abs: dec!(2.0),
                trade_price: dec!(200.0),
                trade_quantity_abs: dec!(4.0),
                expected: dec!(166.66666666666666666666666667),
            },
            // TC2: current weighted 20:1
            TestCase {
                current_price_entry_average: dec!(100.0),
                current_quantity_abs: dec!(20.0),
                trade_price: dec!(200.0),
                trade_quantity_abs: dec!(1.0),
                expected: dec!(104.76190476190476190476190476),
            },
            // TC3: nothing held -> trade price
            TestCase {
                current_price_entry_average: dec!(100.0),
                current_quantity_abs: dec!(0.0),
                trade_price: dec!(200.0),
                trade_quantity_abs: dec!(4.0),
                expected: dec!(200.0),
            },
            // TC4: empty trade -> current average
            TestCase {
                current_price_entry_average: dec!(100.0),
                current_quantity_abs: dec!(10.0),
                trade_price: dec!(0.0),
                trade_quantity_abs: dec!(0.0),
                expected: dec!(100.0),
            },
            // TC5: both quantities zero -> zero
            TestCase {
                current_price_entry_average: dec!(100.0),
                current_quantity_abs: dec!(0.0),
                trade_price: dec!(200.0),
                trade_quantity_abs: dec!(0.0),
                expected: dec!(0.0),
            },
        ];

        for (index, case) in cases.into_iter().enumerate() {
            let actual = calculate_price_entry_average(
                case.current_price_entry_average,
                case.current_quantity_abs,
                case.trade_price,
                case.trade_quantity_abs,
            );
            assert_eq!(actual, case.expected, "TC{} failed", index)
        }
    }

    #[test]
    fn test_calculate_pnl_unrealised() {
        struct TestCase {
            position_side: Side,
            price_entry_average: Decimal,
            quantity_abs: Decimal,
            quantity_abs_max: Decimal,
            fees_enter: Decimal,
            price: Decimal,
            expected: Decimal,
        }

        let cases = [
            // TC0: long, price up: (150 - 100) - 10
            TestCase {
                position_side: Side::Buy,
                price_entry_average: dec!(100.0),
                quantity_abs: dec!(1.0),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                price: dec!(150.0),
                expected: dec!(40.0),
            },
            // TC1: long, price down: (80 - 100) - 10
            TestCase {
                position_side: Side::Buy,
                price_entry_average: dec!(100.0),
                quantity_abs: dec!(1.0),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                price: dec!(80.0),
                expected: dec!(-30.0),
            },
            // TC2: short, price down: (100 - 80) - 10
            TestCase {
                position_side: Side::Sell,
                price_entry_average: dec!(100.0),
                quantity_abs: dec!(1.0),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                price: dec!(80.0),
                expected: dec!(10.0),
            },
            // TC3: short, price up: (100 - 150) - 10
            TestCase {
                position_side: Side::Sell,
                price_entry_average: dec!(100.0),
                quantity_abs: dec!(1.0),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                price: dec!(150.0),
                expected: dec!(-60.0),
            },
            // TC4: half of max still open: 0.5 * (150 - 100) - 0.5 * 10
            TestCase {
                position_side: Side::Buy,
                price_entry_average: dec!(100.0),
                quantity_abs: dec!(0.5),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                price: dec!(150.0),
                expected: dec!(20.0),
            },
            // TC5: flat position
            TestCase {
                position_side: Side::Buy,
                price_entry_average: dec!(100.0),
                quantity_abs: dec!(0.0),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                price: dec!(150.0),
                expected: dec!(0.0),
            },
        ];

        for (index, case) in cases.into_iter().enumerate() {
            let actual = calculate_pnl_unrealised(
                case.position_side,
                case.price_entry_average,
                case.quantity_abs,
                case.quantity_abs_max,
                case.fees_enter,
                case.price,
                Decimal::ONE,
            );
            assert_eq!(actual, case.expected, "TC{} failed", index);
        }
    }

    #[test]
    fn test_calculate_pnl_unrealised_with_contract_size() {
        // Long 1 @ 10 marked at 15, contract size 100: (1500 - 1000) - 1 fee estimate.
        let pnl = calculate_pnl_unrealised(
            Side::Buy,
            dec!(10.0),  // price_entry_average
            dec!(1.0),   // quantity_abs
            dec!(1.0),   // quantity_abs_max
            dec!(1.0),   // fees_enter
            dec!(15.0),  // price
            dec!(100.0), // contract_size
        );
        assert_eq!(pnl, dec!(499.0));

        // Short 2 @ 5 marked at 3, contract size 100: (1000 - 600) - 2 fee estimate.
        let pnl = calculate_pnl_unrealised(
            Side::Sell,
            dec!(5.0),   // price_entry_average
            dec!(2.0),   // quantity_abs
            dec!(2.0),   // quantity_abs_max
            dec!(2.0),   // fees_enter
            dec!(3.0),   // price
            dec!(100.0), // contract_size
        );
        assert_eq!(pnl, dec!(398.0));
    }

    #[test]
    fn test_approximate_remaining_exit_fees() {
        struct TestCase {
            quantity_abs: Decimal,
            quantity_abs_max: Decimal,
            fees_enter: Decimal,
            expected: Decimal,
        }

        let cases = [
            // TC0: fully open
            TestCase {
                quantity_abs: dec!(1.0),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                expected: dec!(10.0),
            },
            // TC1: half open
            TestCase {
                quantity_abs: dec!(0.5),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                expected: dec!(5.0),
            },
            // TC2: flat
            TestCase {
                quantity_abs: dec!(0.0),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                expected: dec!(0.0),
            },
            // TC3: quantity above recorded max scales linearly
            TestCase {
                quantity_abs: dec!(2.0),
                quantity_abs_max: dec!(1.0),
                fees_enter: dec!(10.0),
                expected: dec!(20.0),
            },
            // TC4: zero max guards against division by zero
            TestCase {
                quantity_abs: dec!(1.0),
                quantity_abs_max: dec!(0.0),
                fees_enter: dec!(10.0),
                expected: dec!(0.0),
            },
        ];

        for (index, case) in cases.into_iter().enumerate() {
            let actual = approximate_remaining_exit_fees(
                case.quantity_abs,
                case.quantity_abs_max,
                case.fees_enter,
            );
            assert_eq!(actual, case.expected, "TC{} failed", index);
        }
    }

    #[test]
    fn test_calculate_pnl_realised() {
        struct TestCase {
            side: Side,
            price_entry_average: Decimal,
            closed_quantity: Decimal,
            closed_price: Decimal,
            closed_fee: Decimal,
            expected: Decimal,
        }

        // Each group varies the fee across positive, zero and negative (rebate).
        let cases = [
            // TC0-2: long closed in profit (10 * (150 - 100) = 500)
            TestCase {
                side: Side::Buy,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(150.0),
                closed_fee: dec!(5.0),
                expected: dec!(495.0),
            },
            TestCase {
                side: Side::Buy,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(150.0),
                closed_fee: dec!(0.0),
                expected: dec!(500.0),
            },
            TestCase {
                side: Side::Buy,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(150.0),
                closed_fee: dec!(-5.0),
                expected: dec!(505.0),
            },
            // TC3-5: long closed at a loss (10 * (50 - 100) = -500)
            TestCase {
                side: Side::Buy,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(50.0),
                closed_fee: dec!(5.0),
                expected: dec!(-505.0),
            },
            TestCase {
                side: Side::Buy,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(50.0),
                closed_fee: dec!(0.0),
                expected: dec!(-500.0),
            },
            TestCase {
                side: Side::Buy,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(50.0),
                closed_fee: dec!(-5.0),
                expected: dec!(-495.0),
            },
            // TC6-8: short closed in profit (10 * (100 - 50) = 500)
            TestCase {
                side: Side::Sell,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(50.0),
                closed_fee: dec!(5.0),
                expected: dec!(495.0),
            },
            TestCase {
                side: Side::Sell,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(50.0),
                closed_fee: dec!(0.0),
                expected: dec!(500.0),
            },
            TestCase {
                side: Side::Sell,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(50.0),
                closed_fee: dec!(-5.0),
                expected: dec!(505.0),
            },
            // TC9-11: short closed at a loss (10 * (100 - 150) = -500)
            TestCase {
                side: Side::Sell,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(150.0),
                closed_fee: dec!(5.0),
                expected: dec!(-505.0),
            },
            TestCase {
                side: Side::Sell,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(150.0),
                closed_fee: dec!(0.0),
                expected: dec!(-500.0),
            },
            TestCase {
                side: Side::Sell,
                price_entry_average: dec!(100.0),
                closed_quantity: dec!(10.0),
                closed_price: dec!(150.0),
                closed_fee: dec!(-5.0),
                expected: dec!(-495.0),
            },
        ];

        for (index, case) in cases.into_iter().enumerate() {
            let actual = calculate_pnl_realised(
                case.side,
                case.price_entry_average,
                case.closed_quantity,
                case.closed_price,
                case.closed_fee,
                Decimal::ONE,
            );
            assert_eq!(actual, case.expected, "TC{} failed", index);
        }
    }

    #[test]
    fn test_calculate_pnl_realised_with_contract_size() {
        // Long 1 @ 10 closed at 15, contract size 100: (1500 - 1000) - 1 fee.
        let pnl = calculate_pnl_realised(
            Side::Buy,
            dec!(10.0),  // price_entry_average
            dec!(1.0),   // closed_quantity
            dec!(15.0),  // closed_price
            dec!(1.0),   // closed_fee
            dec!(100.0), // contract_size
        );
        assert_eq!(pnl, dec!(499.0));

        // Short 2 @ 5 closed at 3, contract size 100: (1000 - 600) - 2 fee.
        let pnl = calculate_pnl_realised(
            Side::Sell,
            dec!(5.0),   // price_entry_average
            dec!(2.0),   // closed_quantity
            dec!(3.0),   // closed_price
            dec!(2.0),   // closed_fee
            dec!(100.0), // contract_size
        );
        assert_eq!(pnl, dec!(398.0));
    }

    #[test]
    fn test_calculate_pnl_return() {
        struct TestCase {
            pnl_realised: Decimal,
            price_entry_average: Decimal,
            quantity_abs_max: Decimal,
            expected: Decimal,
        }

        let cases = [
            // TC0: no PnL
            TestCase {
                pnl_realised: dec!(0.0),
                price_entry_average: dec!(100.0),
                quantity_abs_max: dec!(1.0),
                expected: dec!(0.0),
            },
            // TC1: doubled the cost basis
            TestCase {
                pnl_realised: dec!(100.0),
                price_entry_average: dec!(100.0),
                quantity_abs_max: dec!(1.0),
                expected: dec!(1.0),
            },
            // TC2: lost half the cost basis
            TestCase {
                pnl_realised: dec!(-50.0),
                price_entry_average: dec!(100.0),
                quantity_abs_max: dec!(1.0),
                expected: dec!(-0.5),
            },
            // TC3: 500 / (100 * 10)
            TestCase {
                pnl_realised: dec!(500.0),
                price_entry_average: dec!(100.0),
                quantity_abs_max: dec!(10.0),
                expected: dec!(0.5),
            },
            // TC4: zero entry price guards against division by zero
            TestCase {
                pnl_realised: dec!(100.0),
                price_entry_average: dec!(0.0),
                quantity_abs_max: dec!(1.0),
                expected: dec!(0.0),
            },
            // TC5: zero quantity guards against division by zero
            TestCase {
                pnl_realised: dec!(100.0),
                price_entry_average: dec!(100.0),
                quantity_abs_max: dec!(0.0),
                expected: dec!(0.0),
            },
        ];

        for (index, case) in cases.into_iter().enumerate() {
            let actual = calculate_pnl_return(
                case.pnl_realised,
                case.price_entry_average,
                case.quantity_abs_max,
            );
            assert_eq!(actual, case.expected, "TC{} failed", index);
        }
    }
}