use crate::types::Symbol;
/// Net position in a single [`Symbol`], tracked with average-cost accounting.
///
/// Quantities are signed: positive is long, negative is short, zero is flat.
/// Prices, costs and P&L are plain `i64` values in whatever fixed-point unit the
/// caller uses (NOTE(review): the tests write prices like `50_00`, presumably
/// hundredths — confirm against callers).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Position {
    /// Instrument this position refers to.
    pub symbol: Symbol,
    /// Signed net quantity: `> 0` long, `< 0` short, `0` flat.
    pub quantity: i64,
    /// Average entry price of the open quantity, computed as
    /// `total_cost / quantity` with integer (truncating) division; `0` when flat.
    pub avg_entry_price: i64,
    /// Cumulative profit/loss realized by closing (opposing) fills.
    pub realized_pnl: i64,
    /// Signed cost basis of the open quantity (sum of `qty * price` of the
    /// fills that opened it, minus the basis of units since closed); `0` when flat.
    pub total_cost: i64,
}

impl Position {
    /// Creates a flat position for `symbol` with no realized P&L.
    pub fn new(symbol: Symbol) -> Self {
        Self {
            symbol,
            quantity: 0,
            avg_entry_price: 0,
            realized_pnl: 0,
            total_cost: 0,
        }
    }

    /// Applies a fill of signed `qty` units at `price`.
    ///
    /// * `qty > 0` buys, `qty < 0` sells, and `qty == 0` is a no-op.
    /// * A fill opening from flat, or in the same direction as the current
    ///   position, adds to the cost basis and re-derives `avg_entry_price`.
    /// * An opposing fill realizes P&L on the overlapping quantity against
    ///   `avg_entry_price`. If it exactly offsets, the position goes flat. If it
    ///   is smaller, the position is reduced. If it is larger, the position flips
    ///   and the remainder is opened at `price`.
    pub fn apply_fill(&mut self, qty: i64, price: i64) {
        if qty == 0 {
            return;
        }
        let same_direction = (self.quantity >= 0 && qty > 0) || (self.quantity <= 0 && qty < 0);
        if self.quantity == 0 {
            // Opening from flat.
            self.quantity = qty;
            self.avg_entry_price = price;
            self.total_cost = qty * price;
        } else if same_direction {
            // Adding to the position: accumulate basis, then re-average.
            self.total_cost += qty * price;
            self.quantity += qty;
            self.avg_entry_price = self.total_cost / self.quantity;
        } else {
            // Opposing fill: realize P&L on the quantity that is closed.
            let close_qty = qty.abs().min(self.quantity.abs());
            let pnl_per_unit = if self.quantity > 0 {
                price - self.avg_entry_price
            } else {
                self.avg_entry_price - price
            };
            self.realized_pnl += pnl_per_unit * close_qty;

            let net = self.quantity + qty;
            if net == 0 {
                // Fully closed.
                self.quantity = 0;
                self.avg_entry_price = 0;
                self.total_cost = 0;
            } else if (net > 0) == (self.quantity > 0) {
                // Partial close. Here |qty| < |quantity|, so all of `qty` is a
                // closing amount, and it has the opposite sign to `quantity`.
                // Adding the signed `qty * avg` removes the closed units' basis
                // for both longs and shorts. An unsigned `close_qty` would push a
                // short's (negative) basis further from zero.
                self.total_cost += qty * self.avg_entry_price;
                self.quantity = net;
                self.avg_entry_price = self.total_cost / self.quantity;
            } else {
                // Flipped through zero: the leftover opens a new position at `price`.
                self.quantity = net;
                self.avg_entry_price = price;
                self.total_cost = net * price;
            }
        }
    }

    /// Signed value of the open quantity at `price` (`quantity * price`).
    #[inline]
    pub fn market_value(&self, price: i64) -> i64 {
        self.quantity * price
    }

    /// Unrealized P&L of the open quantity if marked at `price`.
    ///
    /// Multiplying by the signed `quantity` makes a price rise positive for longs
    /// and negative for shorts. Returns `0` when flat.
    #[inline]
    pub fn unrealized_pnl(&self, price: i64) -> i64 {
        if self.quantity == 0 {
            return 0;
        }
        (price - self.avg_entry_price) * self.quantity
    }

    /// Returns `true` when there is no open quantity.
    #[inline]
    pub fn is_flat(&self) -> bool {
        self.quantity == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a position in a fixed test symbol and applies `fills`
    /// (`(qty, price)` pairs) in order.
    fn after_fills(fills: &[(i64, i64)]) -> Position {
        let mut pos = Position::new(Symbol::new("AAPL"));
        for &(qty, price) in fills {
            pos.apply_fill(qty, price);
        }
        pos
    }

    #[test]
    fn new_position_is_flat() {
        let pos = after_fills(&[]);
        assert!(pos.is_flat());
        assert_eq!(pos.quantity, 0);
        assert_eq!(pos.realized_pnl, 0);
        assert_eq!(pos.unrealized_pnl(100_00), 0);
    }

    #[test]
    fn open_long() {
        let pos = after_fills(&[(100, 50_00)]);
        assert_eq!(pos.quantity, 100);
        assert_eq!(pos.avg_entry_price, 50_00);
        assert_eq!(pos.market_value(55_00), 100 * 55_00);
        assert_eq!(pos.unrealized_pnl(55_00), 100 * 5_00);
    }

    #[test]
    fn add_to_long_vwap() {
        let pos = after_fills(&[(100, 50_00), (100, 60_00)]);
        assert_eq!(pos.quantity, 200);
        assert_eq!(pos.avg_entry_price, 55_00);
    }

    #[test]
    fn close_long_with_profit() {
        let pos = after_fills(&[(100, 50_00), (-100, 60_00)]);
        assert!(pos.is_flat());
        assert_eq!(pos.realized_pnl, 100 * 10_00);
    }

    #[test]
    fn close_long_with_loss() {
        let pos = after_fills(&[(100, 50_00), (-100, 45_00)]);
        assert!(pos.is_flat());
        assert_eq!(pos.realized_pnl, -100 * 5_00);
    }

    #[test]
    fn partial_close() {
        let pos = after_fills(&[(100, 50_00), (-50, 60_00)]);
        assert_eq!(pos.quantity, 50);
        assert_eq!(pos.avg_entry_price, 50_00);
        assert_eq!(pos.realized_pnl, 50 * 10_00);
    }

    #[test]
    fn flip_long_to_short() {
        let pos = after_fills(&[(100, 50_00), (-150, 60_00)]);
        assert_eq!(pos.quantity, -50);
        assert_eq!(pos.avg_entry_price, 60_00);
        assert_eq!(pos.realized_pnl, 100 * 10_00);
    }

    #[test]
    fn short_position() {
        let pos = after_fills(&[(-100, 50_00)]);
        assert_eq!(pos.quantity, -100);
        // A short gains when the price falls and loses when it rises.
        assert_eq!(pos.unrealized_pnl(45_00), 100 * 5_00);
        assert_eq!(pos.unrealized_pnl(55_00), -100 * 5_00);
    }

    #[test]
    fn close_short_with_profit() {
        let pos = after_fills(&[(-100, 50_00), (100, 40_00)]);
        assert!(pos.is_flat());
        assert_eq!(pos.realized_pnl, 100 * 10_00);
    }

    #[test]
    fn zero_fill_is_noop() {
        let pos = after_fills(&[(100, 50_00), (0, 60_00)]);
        assert_eq!(pos.quantity, 100);
        assert_eq!(pos.avg_entry_price, 50_00);
    }
}