use positive::Positive;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use serde::{Deserialize, Serialize};
use std::fmt;
use utoipa::ToSchema;
/// Rule deciding when a position should be closed.
///
/// Policies are evaluated by [`check_exit_policy`]. Leaf variants test a single
/// condition; [`ExitPolicy::And`] and [`ExitPolicy::Or`] combine child policies.
/// Percentage variants hold fractions (`0.5` means 50%).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize, Deserialize, ToSchema)]
pub enum ExitPolicy {
    /// Take profit once the premium has moved by this fraction of the initial
    /// premium in the position's favour: up for longs, down for shorts.
    ProfitPercent(Decimal),
    /// Stop out once the premium has moved by this fraction of the initial
    /// premium against the position: down for longs, up for shorts.
    LossPercent(Decimal),
    /// Exit when the current premium is within 0.01 of this price.
    FixedPrice(Positive),
    /// Exit when the current premium is at or below this price.
    MinPrice(Positive),
    /// Exit when the current premium is at or above this price.
    MaxPrice(Positive),
    /// Exit once the step number reaches this value.
    TimeSteps(usize),
    /// Exit once the days left to expiration are at or below this value.
    DaysToExpiration(Positive),
    /// Delta threshold. NOTE(review): `check_exit_policy` never fires this
    /// variant; confirm whether it is evaluated elsewhere.
    DeltaThreshold(Decimal),
    /// Exit when the underlying price is within 0.01 of this price.
    UnderlyingPrice(Positive),
    /// Exit when the underlying price is strictly below this price.
    UnderlyingBelow(Positive),
    /// Exit when the underlying price is strictly above this price.
    UnderlyingAbove(Positive),
    /// Hold to expiration (the default). `check_exit_policy` never fires this variant.
    #[default]
    Expiration,
    /// Fires only when every child policy fires.
    And(Vec<ExitPolicy>),
    /// Fires when any child policy fires.
    Or(Vec<ExitPolicy>),
}
impl ExitPolicy {
    /// Builds a take-profit policy.
    ///
    /// `percent` is a fraction of the initial premium (`dec!(0.5)` means 50%).
    #[must_use]
    pub fn profit_target(percent: Decimal) -> Self {
        Self::ProfitPercent(percent)
    }

    /// Builds a stop-loss policy.
    ///
    /// `percent` is a fraction of the initial premium (`dec!(1.0)` means 100%).
    #[must_use]
    pub fn stop_loss(percent: Decimal) -> Self {
        Self::LossPercent(percent)
    }

    /// Builds `Or(profit target, stop loss)`: exit as soon as either condition holds.
    #[must_use]
    pub fn profit_or_loss(profit_percent: Decimal, loss_percent: Decimal) -> Self {
        let take_profit = Self::profit_target(profit_percent);
        let stop = Self::stop_loss(loss_percent);
        Self::Or(Vec::from([take_profit, stop]))
    }

    /// Builds `Or(profit target, step limit)`: exit when the profit target is
    /// reached or when the step number reaches `max_steps`.
    #[must_use]
    pub fn profit_or_time(profit_percent: Decimal, max_steps: usize) -> Self {
        let take_profit = Self::profit_target(profit_percent);
        let deadline = Self::TimeSteps(max_steps);
        Self::Or(Vec::from([take_profit, deadline]))
    }

    /// Returns `true` for the combinator variants `And` and `Or`.
    #[must_use]
    pub const fn is_composite(&self) -> bool {
        matches!(self, Self::Or(_) | Self::And(_))
    }

    /// Counts the leaf conditions in this policy tree.
    ///
    /// A leaf counts as one; a composite contributes the sum of its children,
    /// so an empty `And`/`Or` counts as zero.
    #[must_use]
    pub fn condition_count(&self) -> usize {
        if let Self::And(children) | Self::Or(children) = self {
            children
                .iter()
                .fold(0, |total, child| total + child.condition_count())
        } else {
            1
        }
    }
}
impl fmt::Display for ExitPolicy {
    /// Renders a human-readable description of the policy.
    ///
    /// Percentage variants are stored as fractions and printed scaled by 100
    /// with one decimal place. Composites print as `AND(a, b, ...)` or
    /// `OR(a, b, ...)`, with each child rendered recursively.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Writes `label(`, then the children separated by `, `, then `)`.
        fn write_group(
            f: &mut fmt::Formatter<'_>,
            label: &str,
            children: &[ExitPolicy],
        ) -> fmt::Result {
            f.write_str(label)?;
            f.write_str("(")?;
            let mut separator = "";
            for child in children {
                write!(f, "{separator}{child}")?;
                separator = ", ";
            }
            f.write_str(")")
        }

        let hundred = Decimal::from(100);
        match self {
            Self::ProfitPercent(fraction) => {
                write!(f, "Profit Target: {:.1}%", fraction * hundred)
            }
            Self::LossPercent(fraction) => write!(f, "Stop Loss: {:.1}%", fraction * hundred),
            Self::FixedPrice(level) => write!(f, "Fixed Price: ${level}"),
            Self::MinPrice(level) => write!(f, "Min Price: ${level}"),
            Self::MaxPrice(level) => write!(f, "Max Price: ${level}"),
            Self::TimeSteps(count) => write!(f, "Time Steps: {count}"),
            Self::DaysToExpiration(remaining) => write!(f, "Days to Expiration: {remaining}"),
            Self::DeltaThreshold(threshold) => write!(f, "Delta Threshold: {threshold}"),
            Self::UnderlyingPrice(level) => write!(f, "Underlying Price: ${level}"),
            Self::UnderlyingBelow(level) => write!(f, "Underlying Below: ${level}"),
            Self::UnderlyingAbove(level) => write!(f, "Underlying Above: ${level}"),
            Self::Expiration => f.write_str("Hold to Expiration"),
            Self::And(children) => write_group(f, "AND", children),
            Self::Or(children) => write_group(f, "OR", children),
        }
    }
}
#[allow(clippy::only_used_in_recursion)]
#[must_use]
pub fn check_exit_policy(
policy: &ExitPolicy,
initial_premium: Decimal,
current_premium: Decimal,
step_num: usize,
days_left: Positive,
underlying_price: Positive,
is_long: bool,
) -> Option<ExitPolicy> {
match policy {
ExitPolicy::ProfitPercent(pct) => {
if is_long {
let target = initial_premium * (Decimal::ONE + pct);
if current_premium >= target {
Some(ExitPolicy::ProfitPercent(*pct))
} else {
None
}
} else {
let target = initial_premium * (Decimal::ONE - pct);
if current_premium <= target {
Some(ExitPolicy::ProfitPercent(*pct))
} else {
None
}
}
}
ExitPolicy::LossPercent(pct) => {
if is_long {
let limit = initial_premium * (Decimal::ONE - pct);
if current_premium <= limit {
Some(ExitPolicy::LossPercent(*pct))
} else {
None
}
} else {
let limit = initial_premium * (Decimal::ONE + pct);
if current_premium >= limit {
Some(ExitPolicy::LossPercent(*pct))
} else {
None
}
}
}
ExitPolicy::FixedPrice(price) => {
if (current_premium - price.to_dec()).abs() < dec!(0.01) {
Some(ExitPolicy::FixedPrice(*price))
} else {
None
}
}
ExitPolicy::MinPrice(price) => {
if current_premium <= price.to_dec() {
Some(ExitPolicy::MinPrice(*price))
} else {
None
}
}
ExitPolicy::MaxPrice(price) => {
if current_premium >= price.to_dec() {
Some(ExitPolicy::MaxPrice(*price))
} else {
None
}
}
ExitPolicy::TimeSteps(steps) => {
if step_num >= *steps {
Some(ExitPolicy::TimeSteps(*steps))
} else {
None
}
}
ExitPolicy::DaysToExpiration(days) => {
if days_left <= *days {
Some(ExitPolicy::DaysToExpiration(*days))
} else {
None
}
}
ExitPolicy::UnderlyingPrice(price) => {
if (underlying_price.to_dec() - price.to_dec()).abs() < dec!(0.01) {
Some(ExitPolicy::UnderlyingPrice(*price))
} else {
None
}
}
ExitPolicy::UnderlyingBelow(price) => {
if underlying_price < *price {
Some(ExitPolicy::UnderlyingBelow(*price))
} else {
None
}
}
ExitPolicy::UnderlyingAbove(price) => {
if underlying_price > *price {
Some(ExitPolicy::UnderlyingAbove(*price))
} else {
None
}
}
ExitPolicy::Expiration => None, ExitPolicy::DeltaThreshold(_) => None, ExitPolicy::And(policies) => {
let mut triggered = Vec::new();
for p in policies {
if let Some(triggered_policy) = check_exit_policy(
p,
initial_premium,
current_premium,
step_num,
days_left,
underlying_price,
is_long,
) {
triggered.push(triggered_policy);
} else {
return None; }
}
Some(ExitPolicy::And(triggered))
}
ExitPolicy::Or(policies) => {
for p in policies {
if let Some(triggered_policy) = check_exit_policy(
p,
initial_premium,
current_premium,
step_num,
days_left,
underlying_price,
is_long,
) {
return Some(triggered_policy);
}
}
None
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use positive::pos_or_panic;
    use rust_decimal_macros::dec;
    #[test]
    fn test_profit_target_creation() {
        let policy = ExitPolicy::profit_target(dec!(0.5));
        assert_eq!(policy, ExitPolicy::ProfitPercent(dec!(0.5)));
    }
    #[test]
    fn test_stop_loss_creation() {
        let policy = ExitPolicy::stop_loss(dec!(1.0));
        assert_eq!(policy, ExitPolicy::LossPercent(dec!(1.0)));
    }
    #[test]
    fn test_profit_or_loss_creation() {
        let policy = ExitPolicy::profit_or_loss(dec!(0.5), dec!(1.0));
        match policy {
            ExitPolicy::Or(policies) => {
                assert_eq!(policies.len(), 2);
                assert_eq!(policies[0], ExitPolicy::ProfitPercent(dec!(0.5)));
                assert_eq!(policies[1], ExitPolicy::LossPercent(dec!(1.0)));
            }
            _ => panic!("Expected Or variant"),
        }
    }
    #[test]
    fn test_profit_or_time_creation() {
        let policy = ExitPolicy::profit_or_time(dec!(0.5), 1000);
        match policy {
            ExitPolicy::Or(policies) => {
                assert_eq!(policies.len(), 2);
                assert_eq!(policies[0], ExitPolicy::ProfitPercent(dec!(0.5)));
                assert_eq!(policies[1], ExitPolicy::TimeSteps(1000));
            }
            _ => panic!("Expected Or variant"),
        }
    }
    #[test]
    fn test_is_composite() {
        let simple = ExitPolicy::ProfitPercent(dec!(0.5));
        assert!(!simple.is_composite());
        let composite = ExitPolicy::Or(vec![
            ExitPolicy::ProfitPercent(dec!(0.5)),
            ExitPolicy::LossPercent(dec!(1.0)),
        ]);
        assert!(composite.is_composite());
    }
    #[test]
    fn test_condition_count() {
        let simple = ExitPolicy::ProfitPercent(dec!(0.5));
        assert_eq!(simple.condition_count(), 1);
        let composite = ExitPolicy::Or(vec![
            ExitPolicy::ProfitPercent(dec!(0.5)),
            ExitPolicy::LossPercent(dec!(1.0)),
        ]);
        assert_eq!(composite.condition_count(), 2);
        let nested = ExitPolicy::And(vec![
            ExitPolicy::Or(vec![
                ExitPolicy::ProfitPercent(dec!(0.5)),
                ExitPolicy::LossPercent(dec!(1.0)),
            ]),
            ExitPolicy::TimeSteps(1000),
        ]);
        assert_eq!(nested.condition_count(), 3);
    }
    #[test]
    fn test_display_profit_percent() {
        let policy = ExitPolicy::ProfitPercent(dec!(0.5));
        assert_eq!(format!("{policy}"), "Profit Target: 50.0%");
    }
    #[test]
    fn test_display_loss_percent() {
        let policy = ExitPolicy::LossPercent(dec!(1.0));
        assert_eq!(format!("{policy}"), "Stop Loss: 100.0%");
    }
    #[test]
    fn test_display_fixed_price() {
        let policy = ExitPolicy::FixedPrice(pos_or_panic!(50.0));
        assert_eq!(format!("{policy}"), "Fixed Price: $50");
    }
    #[test]
    fn test_display_time_steps() {
        let policy = ExitPolicy::TimeSteps(1000);
        assert_eq!(format!("{policy}"), "Time Steps: 1000");
    }
    #[test]
    fn test_display_expiration() {
        let policy = ExitPolicy::Expiration;
        assert_eq!(format!("{policy}"), "Hold to Expiration");
    }
    #[test]
    fn test_display_or_composite() {
        let policy = ExitPolicy::Or(vec![
            ExitPolicy::ProfitPercent(dec!(0.5)),
            ExitPolicy::LossPercent(dec!(1.0)),
        ]);
        let display = format!("{policy}");
        assert!(display.contains("OR("));
        assert!(display.contains("Profit Target: 50.0%"));
        assert!(display.contains("Stop Loss: 100.0%"));
    }
    #[test]
    fn test_display_and_composite() {
        let policy = ExitPolicy::And(vec![
            ExitPolicy::ProfitPercent(dec!(0.5)),
            ExitPolicy::TimeSteps(1000),
        ]);
        let display = format!("{policy}");
        assert!(display.contains("AND("));
        assert!(display.contains("Profit Target: 50.0%"));
        assert!(display.contains("Time Steps: 1000"));
    }
    #[test]
    fn test_serialization() {
        let policy = ExitPolicy::profit_or_loss(dec!(0.5), dec!(1.0));
        let json = serde_json::to_string(&policy).unwrap();
        assert!(json.contains("Or"));
        assert!(json.contains("ProfitPercent"));
        assert!(json.contains("LossPercent"));
    }
    #[test]
    fn test_deserialization() {
        let json = r#"{"ProfitPercent":"0.5"}"#;
        let policy: ExitPolicy = serde_json::from_str(json).unwrap();
        assert_eq!(policy, ExitPolicy::ProfitPercent(dec!(0.5)));
    }
    #[test]
    fn test_all_variants() {
        let policies = vec![
            ExitPolicy::ProfitPercent(dec!(0.5)),
            ExitPolicy::LossPercent(dec!(1.0)),
            ExitPolicy::FixedPrice(pos_or_panic!(50.0)),
            ExitPolicy::MinPrice(pos_or_panic!(5.0)),
            ExitPolicy::MaxPrice(Positive::HUNDRED),
            ExitPolicy::TimeSteps(1000),
            ExitPolicy::DaysToExpiration(Positive::TWO),
            ExitPolicy::DeltaThreshold(dec!(0.5)),
            ExitPolicy::UnderlyingPrice(pos_or_panic!(4000.0)),
            ExitPolicy::UnderlyingBelow(pos_or_panic!(3900.0)),
            ExitPolicy::UnderlyingAbove(pos_or_panic!(4100.0)),
            ExitPolicy::Expiration,
        ];
        for policy in policies {
            let _ = format!("{policy}");
            let _ = policy.clone();
        }
    }

    /// Evaluates `policy` with an initial premium of 10, step 0, two days
    /// left and the underlying at 100; only the current premium varies.
    fn check_premium(policy: &ExitPolicy, current: Decimal, is_long: bool) -> Option<ExitPolicy> {
        check_exit_policy(
            policy,
            dec!(10),
            current,
            0,
            Positive::TWO,
            Positive::HUNDRED,
            is_long,
        )
    }

    #[test]
    fn test_check_profit_percent_long_and_short() {
        let policy = ExitPolicy::ProfitPercent(dec!(0.5));
        // Long: target is 10 * 1.5 = 15.
        assert_eq!(check_premium(&policy, dec!(15), true), Some(policy.clone()));
        assert_eq!(check_premium(&policy, dec!(14.99), true), None);
        // Short: target is 10 * 0.5 = 5.
        assert_eq!(check_premium(&policy, dec!(5), false), Some(policy.clone()));
        assert_eq!(check_premium(&policy, dec!(5.01), false), None);
    }

    #[test]
    fn test_check_loss_percent_long_and_short() {
        let policy = ExitPolicy::LossPercent(dec!(0.5));
        // Long: limit is 10 * 0.5 = 5.
        assert_eq!(check_premium(&policy, dec!(5), true), Some(policy.clone()));
        assert_eq!(check_premium(&policy, dec!(6), true), None);
        // Short: limit is 10 * 1.5 = 15.
        assert_eq!(check_premium(&policy, dec!(15), false), Some(policy.clone()));
        assert_eq!(check_premium(&policy, dec!(14), false), None);
    }

    #[test]
    fn test_check_premium_price_levels() {
        // FixedPrice uses a strict 0.01 tolerance.
        let fixed = ExitPolicy::FixedPrice(pos_or_panic!(10.0));
        assert_eq!(check_premium(&fixed, dec!(10.005), true), Some(fixed.clone()));
        assert_eq!(check_premium(&fixed, dec!(10.01), true), None);

        let min = ExitPolicy::MinPrice(pos_or_panic!(5.0));
        assert_eq!(check_premium(&min, dec!(5), true), Some(min.clone()));
        assert_eq!(check_premium(&min, dec!(5.5), true), None);

        let max = ExitPolicy::MaxPrice(Positive::HUNDRED);
        assert_eq!(check_premium(&max, dec!(100), true), Some(max.clone()));
        assert_eq!(check_premium(&max, dec!(99), true), None);
    }

    #[test]
    fn test_check_time_and_days() {
        let steps = ExitPolicy::TimeSteps(5);
        assert_eq!(
            check_exit_policy(&steps, dec!(10), dec!(10), 5, Positive::TWO, Positive::HUNDRED, true),
            Some(steps.clone())
        );
        assert_eq!(
            check_exit_policy(&steps, dec!(10), dec!(10), 4, Positive::TWO, Positive::HUNDRED, true),
            None
        );

        let days = ExitPolicy::DaysToExpiration(Positive::TWO);
        assert_eq!(
            check_exit_policy(&days, dec!(10), dec!(10), 0, Positive::TWO, Positive::HUNDRED, true),
            Some(days.clone())
        );
        assert_eq!(
            check_exit_policy(
                &days,
                dec!(10),
                dec!(10),
                0,
                Positive::HUNDRED,
                Positive::HUNDRED,
                true
            ),
            None
        );
    }

    #[test]
    fn test_check_underlying_levels() {
        // check_premium places the underlying at exactly 100.
        let at = ExitPolicy::UnderlyingPrice(Positive::HUNDRED);
        assert_eq!(check_premium(&at, dec!(10), true), Some(at.clone()));

        // Below/Above are strict comparisons.
        let below_equal = ExitPolicy::UnderlyingBelow(Positive::HUNDRED);
        assert_eq!(check_premium(&below_equal, dec!(10), true), None);
        let below = ExitPolicy::UnderlyingBelow(pos_or_panic!(101.0));
        assert_eq!(check_premium(&below, dec!(10), true), Some(below.clone()));

        let above_equal = ExitPolicy::UnderlyingAbove(Positive::HUNDRED);
        assert_eq!(check_premium(&above_equal, dec!(10), true), None);
        let above = ExitPolicy::UnderlyingAbove(pos_or_panic!(99.0));
        assert_eq!(check_premium(&above, dec!(10), true), Some(above.clone()));
    }

    #[test]
    fn test_check_never_triggering_variants() {
        assert_eq!(check_premium(&ExitPolicy::Expiration, dec!(1000), true), None);
        assert_eq!(
            check_premium(&ExitPolicy::DeltaThreshold(dec!(0.5)), dec!(1000), true),
            None
        );
    }

    #[test]
    fn test_check_or_returns_first_triggered_child() {
        let policy = ExitPolicy::Or(vec![
            ExitPolicy::TimeSteps(100),
            ExitPolicy::ProfitPercent(dec!(0.5)),
            ExitPolicy::MaxPrice(pos_or_panic!(12.0)),
        ]);
        // Premium 15 satisfies both the profit target and the max price;
        // only the first satisfied child is reported, unwrapped.
        assert_eq!(
            check_premium(&policy, dec!(15), true),
            Some(ExitPolicy::ProfitPercent(dec!(0.5)))
        );
        assert_eq!(check_premium(&policy, dec!(11), true), None);
    }

    #[test]
    fn test_check_and_requires_every_child() {
        let policy = ExitPolicy::And(vec![
            ExitPolicy::ProfitPercent(dec!(0.5)),
            ExitPolicy::TimeSteps(0),
        ]);
        assert_eq!(check_premium(&policy, dec!(15), true), Some(policy.clone()));
        assert_eq!(check_premium(&policy, dec!(14), true), None);
    }

    #[test]
    fn test_check_empty_composites() {
        // An empty AND is vacuously satisfied; an empty OR never fires.
        assert_eq!(
            check_premium(&ExitPolicy::And(vec![]), dec!(10), true),
            Some(ExitPolicy::And(vec![]))
        );
        assert_eq!(check_premium(&ExitPolicy::Or(vec![]), dec!(10), true), None);
    }
}