use polars::prelude::PlSmallStr;
use serde::{Deserialize, Serialize};
use strum::{Display, EnumCount, EnumIter, EnumString, IntoStaticStr};
use crate::{
data::domain::{Instrument, Price, Quantity, Symbol},
error::{AgentError, ChapatyResult},
};
/// Direction of a trade.
///
/// The strum-generated string forms (`Display`, `FromStr` via `EnumString`,
/// and `&'static str` via `IntoStaticStr`) are lowercase: `"long"` / `"short"`.
///
/// NOTE(review): no `#[serde(...)]` attribute is present, so the strum casing
/// does not apply to the serde derives — confirm that is intended.
#[derive(
Copy,
Clone,
Debug,
EnumString,
Display,
PartialEq,
Eq,
Hash,
Deserialize,
Serialize,
PartialOrd,
Ord,
IntoStaticStr,
)]
#[strum(serialize_all = "lowercase")]
pub enum TradeType {
/// Long direction: `price_diff` is `exit - entry`, and validation expects
/// `stop_loss < entry < take_profit`.
Long,
/// Short direction: `price_diff` is `entry - exit`, and validation expects
/// `take_profit < entry < stop_loss`.
Short,
}
impl From<TradeType> for PlSmallStr {
fn from(value: TradeType) -> Self {
value.as_str().into()
}
}
impl TradeType {
pub fn name(&self) -> PlSmallStr {
(*self).into()
}
pub fn as_str(&self) -> &'static str {
self.into()
}
}
impl TradeType {
pub fn price_ordering_validation(
&self,
stop_loss: Option<Price>,
entry: Option<Price>,
take_profit: Option<Price>,
) -> ChapatyResult<()> {
fn err_msg(msg: &str) -> ChapatyResult<()> {
Err(AgentError::InvalidInput(msg.to_string()).into())
}
use TradeType::*;
match (self, stop_loss, entry, take_profit) {
(_, None, None, None) => Ok(()),
(_, None, None, Some(_)) | (_, None, Some(_), None) | (_, Some(_), None, None) => {
Ok(())
}
(Long, None, Some(en), Some(tp)) if en.0 < tp.0 => Ok(()),
(Short, None, Some(en), Some(tp)) if tp.0 < en.0 => Ok(()),
(Long, Some(sl), None, Some(tp)) if sl.0 < tp.0 => Ok(()),
(Short, Some(sl), None, Some(tp)) if tp.0 < sl.0 => Ok(()),
(Long, Some(sl), Some(en), None) if sl.0 < en.0 => Ok(()),
(Short, Some(sl), Some(en), None) if en.0 < sl.0 => Ok(()),
(Long, Some(sl), Some(en), Some(tp)) if sl.0 < en.0 && en.0 < tp.0 => Ok(()),
(Short, Some(sl), Some(en), Some(tp)) if tp.0 < en.0 && en.0 < sl.0 => Ok(()),
(Long, Some(_), Some(_), Some(_)) => {
err_msg("For long trades: stop_loss < entry < take_profit")
}
(Short, Some(_), Some(_), Some(_)) => {
err_msg("For short trades: take_profit < entry < stop_loss")
}
(Long, None, Some(_), Some(_)) => err_msg("For long trades: entry < take_profit"),
(Short, None, Some(_), Some(_)) => err_msg("For short trades: take_profit < entry"),
(Long, Some(_), None, Some(_)) => err_msg("For long trades: stop_loss < take_profit"),
(Short, Some(_), None, Some(_)) => err_msg("For short trades: take_profit < stop_loss"),
(Long, Some(_), Some(_), None) => err_msg("For long trades: stop_loss < entry"),
(Short, Some(_), Some(_), None) => err_msg("For short trades: entry < stop_loss"),
}
}
pub fn price_diff(&self, entry: Price, exit: Price) -> Price {
match self {
TradeType::Long => exit - entry,
TradeType::Short => entry - exit,
}
}
pub fn calculate_pnl(&self, entry: Price, exit: Price, qty: Quantity, symbol: &Symbol) -> f64 {
let price_dist = self.price_diff(entry, exit);
let ticks = symbol.price_to_ticks(price_dist);
let unit_pnl = symbol.ticks_to_usd(ticks);
unit_pnl * qty.0
}
}
/// Reason attached to the termination of a trade.
///
/// The strum-generated string forms (`Display`, `FromStr` via `EnumString`,
/// and `&'static str` via `IntoStaticStr`) are snake_case, e.g. `"stop_loss"`.
///
/// NOTE(review): the variant descriptions below are inferred from the variant
/// names; the code that assigns them is not visible here — confirm.
#[derive(
Copy,
Clone,
Debug,
EnumString,
Display,
PartialEq,
Eq,
Hash,
Deserialize,
Serialize,
PartialOrd,
Ord,
IntoStaticStr,
)]
#[strum(serialize_all = "snake_case")]
pub enum TerminationReason {
/// String form `"stop_loss"`; presumably the stop-loss level was hit.
StopLoss,
/// String form `"take_profit"`; presumably the take-profit level was hit.
TakeProfit,
/// String form `"market_close"`; presumably closed at the end of the session.
MarketClose,
/// String form `"canceled"`; presumably canceled before completion.
Canceled,
}
impl From<TerminationReason> for PlSmallStr {
fn from(value: TerminationReason) -> Self {
value.as_str().into()
}
}
impl TerminationReason {
pub fn name(&self) -> PlSmallStr {
(*self).into()
}
pub fn as_str(&self) -> &'static str {
self.into()
}
}
/// Risk-to-reward relationship of a trade, stored as absolute magnitudes.
///
/// `ratio` is `risk / reward`, so values below `1.0` mean the potential
/// reward exceeds the risk taken.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RiskRewardRatio {
    /// Absolute value of the risk passed to [`RiskRewardRatio::new`].
    risk: f64,
    /// Absolute value of the reward passed to [`RiskRewardRatio::new`].
    reward: f64,
    /// Pre-computed `risk / reward`, with the zero cases special-cased.
    ratio: f64,
}

impl RiskRewardRatio {
    /// Builds a ratio from a risk and a reward amount.
    ///
    /// Signs are discarded: both inputs are taken as absolute values.
    /// The zero cases are handled explicitly:
    ///
    /// * `risk == 0` (whatever the reward) gives a ratio of `0.0`;
    /// * `reward == 0` with non-zero risk gives `f64::INFINITY`;
    /// * otherwise the ratio is `risk / reward`.
    pub fn new(risk_usd: f64, reward_usd: f64) -> Self {
        let (risk, reward) = (risk_usd.abs(), reward_usd.abs());

        let ratio = match (risk == 0.0, reward == 0.0) {
            // Nothing at risk: the ratio is zero even when reward is also zero.
            (true, _) => 0.0,
            // Risk without any reward: avoid the division and report infinity.
            (false, true) => f64::INFINITY,
            (false, false) => risk / reward,
        };

        Self {
            risk,
            reward,
            ratio,
        }
    }

    /// Returns the pre-computed `risk / reward` value.
    pub fn ratio(&self) -> f64 {
        self.ratio
    }
}

/// Formats the ratio with two decimals followed by `:1`, e.g. `0.50:1`.
impl std::fmt::Display for RiskRewardRatio {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { ratio, .. } = *self;
        write!(f, "{ratio:.2}:1")
    }
}
/// Kind of state, without any associated data.
///
/// The strum-generated string forms (`Display`, `FromStr` via `EnumString`,
/// and `&'static str` via `IntoStaticStr`) are lowercase, e.g. `"active"`.
/// `EnumIter` and `EnumCount` additionally expose iteration over, and the
/// number of, variants.
///
/// The derived `PartialOrd` follows declaration order
/// (`Active < Closed < Pending < Canceled`); `Ord` is not derived.
///
/// NOTE(review): the variant descriptions below are inferred from the variant
/// names; the state machine that uses them is not visible here — confirm.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
PartialOrd,
Eq,
Hash,
Serialize,
Deserialize,
EnumString,
EnumIter,
EnumCount,
Display,
IntoStaticStr,
)]
#[strum(serialize_all = "lowercase")]
pub enum StateKind {
/// String form `"active"`; presumably a currently open position.
Active,
/// String form `"closed"`; presumably a finished position.
Closed,
/// String form `"pending"`; presumably not yet entered.
Pending,
/// String form `"canceled"`; presumably canceled before becoming active.
Canceled,
}
impl From<StateKind> for PlSmallStr {
fn from(value: StateKind) -> Self {
value.as_str().into()
}
}
impl StateKind {
pub fn name(&self) -> PlSmallStr {
(*self).into()
}
pub fn as_str(&self) -> &'static str {
self.into()
}
}
// Unit tests for `TradeType` (price ordering, price difference, PnL) and
// `RiskRewardRatio` (ratio computation and display formatting).
#[cfg(test)]
mod tests {
use crate::data::domain::{ContractMonth, ContractYear, FutureContract, FutureRoot};
use super::*;
// The three helpers below all build a plain `Price`; the distinct names only
// make the call sites read as stop-loss / entry / take-profit.
fn sl(v: f64) -> Price {
Price(v)
}
fn en(v: f64) -> Price {
Price(v)
}
fn tp(v: f64) -> Price {
Price(v)
}
// Long: every combination of two or three legs in the order
// stop_loss < entry < take_profit must be accepted.
#[test]
fn test_long_valid_cases() {
assert!(
TradeType::Long
.price_ordering_validation(Some(sl(90.0)), Some(en(100.0)), Some(tp(110.0)))
.is_ok()
);
assert!(
TradeType::Long
.price_ordering_validation(None, Some(en(100.0)), Some(tp(120.0)))
.is_ok()
);
assert!(
TradeType::Long
.price_ordering_validation(Some(sl(80.0)), Some(en(100.0)), None)
.is_ok()
);
assert!(
TradeType::Long
.price_ordering_validation(Some(sl(80.0)), None, Some(tp(120.0)))
.is_ok()
);
}
// Long: each pair (and the full triple) in the wrong order must be rejected.
#[test]
fn test_long_invalid_cases() {
assert!(
TradeType::Long
.price_ordering_validation(None, Some(en(120.0)), Some(tp(100.0)))
.is_err()
);
assert!(
TradeType::Long
.price_ordering_validation(Some(sl(100.0)), Some(en(90.0)), None)
.is_err()
);
assert!(
TradeType::Long
.price_ordering_validation(Some(sl(120.0)), None, Some(tp(100.0)))
.is_err()
);
assert!(
TradeType::Long
.price_ordering_validation(Some(sl(100.0)), Some(en(110.0)), Some(tp(105.0)))
.is_err()
);
}
// Short: every combination of two or three legs in the order
// take_profit < entry < stop_loss must be accepted.
#[test]
fn test_short_valid_cases() {
assert!(
TradeType::Short
.price_ordering_validation(Some(sl(120.0)), Some(en(110.0)), Some(tp(100.0)))
.is_ok()
);
assert!(
TradeType::Short
.price_ordering_validation(None, Some(en(110.0)), Some(tp(100.0)))
.is_ok()
);
assert!(
TradeType::Short
.price_ordering_validation(Some(sl(120.0)), Some(en(100.0)), None)
.is_ok()
);
assert!(
TradeType::Short
.price_ordering_validation(Some(sl(120.0)), None, Some(tp(100.0)))
.is_ok()
);
}
// Short: each pair (and the full triple) in the wrong order must be rejected.
#[test]
fn test_short_invalid_cases() {
assert!(
TradeType::Short
.price_ordering_validation(None, Some(en(90.0)), Some(tp(110.0)))
.is_err()
);
assert!(
TradeType::Short
.price_ordering_validation(Some(sl(90.0)), Some(en(100.0)), None)
.is_err()
);
assert!(
TradeType::Short
.price_ordering_validation(Some(sl(100.0)), None, Some(tp(110.0)))
.is_err()
);
assert!(
TradeType::Short
.price_ordering_validation(Some(sl(100.0)), Some(en(90.0)), Some(tp(95.0)))
.is_err()
);
}
// With all three legs present, BOTH adjacent comparisons must hold: one valid
// pair must never mask the other pair being invalid.
#[test]
fn test_three_legged_invariants() {
assert!(
TradeType::Long
.price_ordering_validation(Some(sl(105.0)), Some(en(100.0)), Some(tp(110.0)))
.is_err(),
"Long: Valid Entry/TP should not mask invalid SL/Entry"
);
assert!(
TradeType::Long
.price_ordering_validation(Some(sl(90.0)), Some(en(100.0)), Some(tp(95.0)))
.is_err(),
"Long: Valid SL/Entry should not mask invalid Entry/TP"
);
assert!(
TradeType::Short
.price_ordering_validation(Some(sl(95.0)), Some(en(100.0)), Some(tp(90.0)))
.is_err(),
"Short: Valid TP/Entry should not mask invalid Entry/SL"
);
assert!(
TradeType::Short
.price_ordering_validation(Some(sl(110.0)), Some(en(100.0)), Some(tp(105.0)))
.is_err(),
"Short: Valid Entry/SL should not mask invalid TP/Entry"
);
}
// `price_diff` is signed in the trade's favour: positive when the price moved
// up for a long, or down for a short.
#[test]
fn test_price_diff_logic() {
assert_eq!(
TradeType::Long.price_diff(Price(100.0), Price(110.0)),
Price(10.0),
"Long should be positive when price goes up"
);
assert_eq!(
TradeType::Long.price_diff(Price(110.0), Price(100.0)),
Price(-10.0),
"Long should be negative when price goes down"
);
assert_eq!(
TradeType::Short.price_diff(Price(110.0), Price(100.0)),
Price(10.0),
"Short should be positive when price goes down"
);
assert_eq!(
TradeType::Short.price_diff(Price(100.0), Price(110.0)),
Price(-10.0),
"Short should be negative when price goes up"
);
}
// PnL must come out exact even when an input price carries a tiny offset
// (1e-8 here) from the intended value.
// NOTE(review): the expected 62.5 depends on the tick size/value that
// `Symbol` defines for this contract elsewhere — confirm if those change.
#[test]
fn test_pnl_cleans_dirty_inputs() {
let eur = Symbol::Future(FutureContract {
root: FutureRoot::EurUsd,
month: ContractMonth::December,
year: ContractYear::Y5,
});
let entry = Price(1.10000);
// Exit is 0.0005 above entry plus a 1e-8 offset.
let dirty_exit = Price(1.10050001);
let pnl = TradeType::Long.calculate_pnl(entry, dirty_exit, Quantity(1.0), &eur);
assert_eq!(pnl, 62.5, "Long PnL failed to snap dirty input to grid");
// Short case: the entry is 1e-8 below 1.0995, the exit is clean.
let dirty_entry = Price(1.09949999); let clean_exit = Price(1.09900);
let pnl_short =
TradeType::Short.calculate_pnl(dirty_entry, clean_exit, Quantity(1.0), &eur);
assert_eq!(
pnl_short, 62.5,
"Short PnL failed to snap dirty input to grid"
);
}
// The ratio is risk / reward: below 1.0 is favourable, above 1.0 unfavourable.
#[test]
fn calculates_standard_ratios() {
let favorable = RiskRewardRatio::new(50.0, 100.0);
assert_eq!(favorable.ratio(), 0.5);
let unfavorable = RiskRewardRatio::new(100.0, 50.0);
assert_eq!(unfavorable.ratio(), 2.0);
let neutral = RiskRewardRatio::new(100.0, 100.0);
assert_eq!(neutral.ratio(), 1.0);
}
// Negative inputs are stored and used as absolute values.
#[test]
fn handles_negative_inputs_gracefully() {
let rrr = RiskRewardRatio::new(-50.0, -100.0);
assert_eq!(rrr.risk, 50.0);
assert_eq!(rrr.reward, 100.0);
assert_eq!(rrr.ratio(), 0.5);
}
// Zero handling: zero risk gives 0.0 (even with zero reward); zero reward with
// non-zero risk gives infinity.
#[test]
fn handles_edge_cases() {
let zero_risk = RiskRewardRatio::new(0.0, 100.0);
assert_eq!(zero_risk.ratio(), 0.0);
let zero_reward = RiskRewardRatio::new(100.0, 0.0);
assert_eq!(zero_reward.ratio(), f64::INFINITY);
let zero_zero = RiskRewardRatio::new(0.0, 0.0);
assert_eq!(zero_zero.ratio(), 0.0);
}
// Display prints the ratio with two decimals followed by ":1"; an infinite
// ratio prints as "inf:1".
#[test]
fn test_display_formatting() {
let rrr = RiskRewardRatio::new(50.0, 100.0);
assert_eq!(rrr.to_string(), "0.50:1");
let rrr_whole = RiskRewardRatio::new(200.0, 100.0);
assert_eq!(rrr_whole.to_string(), "2.00:1");
let rrr_inf = RiskRewardRatio::new(100.0, 0.0);
assert_eq!(rrr_inf.to_string(), "inf:1");
}
}