use super::{define_signed_value_type, Error, ParamKind, Quantity, Side};
// Declares `PositionSize`, a signed decimal value type tagged with
// `ParamKind::PositionSize`. Sign convention (see `from_quantity_and_side`):
// Buy-side quantities become positive sizes, Sell-side quantities negative ones.
// NOTE(review): the macro is defined in the parent module; it presumably
// generates `new`, `ZERO`, `to_decimal` and the equality/ordering impls used
// in this file — confirm against the macro definition.
define_signed_value_type!(
PositionSize,
ParamKind::PositionSize
);
impl PositionSize {
    /// Converts an unsigned `quantity` traded on `side` into a signed position size.
    ///
    /// `Side::Buy` maps to `+quantity`, `Side::Sell` maps to `-quantity`.
    pub fn from_quantity_and_side(quantity: Quantity, side: Side) -> Self {
        let magnitude = quantity.to_decimal();
        let signed = match side {
            Side::Sell => -magnitude,
            Side::Buy => magnitude,
        };
        Self::new(signed)
    }

    /// Splits the position into its absolute quantity and the side that opens it.
    ///
    /// Positive sizes report `Side::Buy`, negative sizes report `Side::Sell`.
    /// A zero position yields `(Quantity::ZERO, Side::Buy)`.
    pub fn to_open_quantity(self) -> (Quantity, Side) {
        if self == Self::ZERO {
            return (Quantity::ZERO, Side::Buy);
        }
        let side = if self < Self::ZERO { Side::Sell } else { Side::Buy };
        // `abs()` never yields a negative value, which is presumably the invariant
        // `new_unchecked` skips validating — NOTE(review): confirm in `Quantity`.
        let magnitude = self.to_decimal().abs();
        (Quantity::new_unchecked(magnitude), side)
    }

    /// Returns the absolute quantity and the side of the order that would bring
    /// the position back to zero.
    ///
    /// Positive sizes close with `Side::Sell`, negative sizes with `Side::Buy`.
    /// A zero position yields `(Quantity::ZERO, None)`.
    pub fn to_close_quantity(self) -> (Quantity, Option<Side>) {
        if self == Self::ZERO {
            return (Quantity::ZERO, None);
        }
        // The closing side is always the opposite of the opening side.
        let (magnitude, opening_side) = self.to_open_quantity();
        let closing_side = match opening_side {
            Side::Buy => Side::Sell,
            Side::Sell => Side::Buy,
        };
        (magnitude, Some(closing_side))
    }

    /// Adds `qty` to this position, signed by `side`: `Side::Buy` adds it and
    /// `Side::Sell` subtracts it. The result may cross zero.
    ///
    /// # Errors
    ///
    /// Returns `Error::Overflow { param: ParamKind::PositionSize }` when the
    /// underlying decimal addition overflows.
    pub fn checked_add_quantity(self, qty: Quantity, side: Side) -> Result<Self, Error> {
        let magnitude = qty.to_decimal();
        let delta = match side {
            Side::Sell => -magnitude,
            Side::Buy => magnitude,
        };
        self.to_decimal()
            .checked_add(delta)
            .map(Self::new)
            .ok_or(Error::Overflow {
                param: ParamKind::PositionSize,
            })
    }
}
#[cfg(test)]
mod tests {
    use super::PositionSize;
    use crate::param::{Error, ParamKind, Quantity, Side};
    use rust_decimal::Decimal;

    /// Parses a decimal literal; a malformed literal is a bug in the test itself.
    fn d(value: &str) -> Decimal {
        value
            .parse()
            .expect("decimal literal in tests must be valid")
    }

    #[test]
    fn converts_to_open_and_close_quantities() {
        let short = PositionSize::new(d("-0.5"));
        let long = PositionSize::new(d("0.5"));
        let expected_qty = Quantity::new(d("0.5")).expect("must be valid");
        assert_eq!(short.to_open_quantity(), (expected_qty, Side::Sell));
        assert_eq!(long.to_open_quantity(), (expected_qty, Side::Buy));
        assert_eq!(short.to_close_quantity(), (expected_qty, Some(Side::Buy)));
        assert_eq!(long.to_close_quantity(), (expected_qty, Some(Side::Sell)));
        assert_eq!(
            PositionSize::ZERO.to_open_quantity(),
            (Quantity::ZERO, Side::Buy)
        );
        assert_eq!(
            PositionSize::ZERO.to_close_quantity(),
            (Quantity::ZERO, None)
        );
    }

    #[test]
    fn builds_from_quantity_and_side() {
        let quantity = Quantity::new(d("2")).expect("must be valid");
        assert_eq!(
            PositionSize::from_quantity_and_side(quantity, Side::Buy),
            PositionSize::new(d("2"))
        );
        assert_eq!(
            PositionSize::from_quantity_and_side(quantity, Side::Sell),
            PositionSize::new(d("-2"))
        );
    }

    #[test]
    fn supports_add_with_quantity_and_side() {
        let start = PositionSize::new(d("1.5"));
        let qty = Quantity::new(d("0.5")).expect("must be valid");
        assert_eq!(
            start
                .checked_add_quantity(qty, Side::Buy)
                .expect("must be valid"),
            PositionSize::new(d("2.0"))
        );
        assert_eq!(
            start
                .checked_add_quantity(qty, Side::Sell)
                .expect("must be valid"),
            PositionSize::new(d("1.0"))
        );
    }

    #[test]
    fn add_with_quantity_flips_position_sign() {
        let short = PositionSize::new(d("-1.5"));
        let qty = Quantity::new(d("2.0")).expect("must be valid");
        let result: PositionSize = short
            .checked_add_quantity(qty, Side::Buy)
            .expect("must be valid");
        assert_eq!(result, PositionSize::new(d("0.5")));

        let long = PositionSize::new(d("1.0"));
        let result = long
            .checked_add_quantity(qty, Side::Sell)
            .expect("must be valid");
        assert_eq!(result, PositionSize::new(d("-1.0")));

        // Adding a zero quantity must leave the position unchanged.
        let zero_qty = Quantity::ZERO;
        let result = short
            .checked_add_quantity(zero_qty, Side::Buy)
            .expect("must be valid");
        assert_eq!(result, short);
    }

    #[test]
    fn checked_add_quantity_reports_overflow() {
        // Build the quantity the same way as every other test (`Quantity::new`),
        // rather than relying on a `from_str` whose trait is not imported here.
        let qty = Quantity::new(d("1")).expect("must be valid");

        let max_long = PositionSize::new(Decimal::MAX);
        assert_eq!(
            max_long.checked_add_quantity(qty, Side::Buy),
            Err(Error::Overflow {
                param: ParamKind::PositionSize
            })
        );

        // The Sell side must overflow symmetrically at the negative bound.
        let max_short = PositionSize::new(Decimal::MIN);
        assert_eq!(
            max_short.checked_add_quantity(qty, Side::Sell),
            Err(Error::Overflow {
                param: ParamKind::PositionSize
            })
        );
    }
}